promql/extension_plan/scalar_calculate.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use datafusion::common::stats::Precision;
use datafusion::common::{DFSchema, DFSchemaRef, Result as DataFusionResult, Statistics};
use datafusion::error::DataFusionError;
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::{EmptyRelation, LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PlanProperties,
    RecordBatchStream, SendableRecordBatchStream,
};
use datafusion::prelude::Expr;
use datafusion::sql::TableReference;
use datafusion_expr::col;
use datatypes::arrow::array::{Array, Float64Array, StringArray, TimestampMillisecondArray};
use datatypes::arrow::compute::{CastOptions, cast_with_options, concat_batches};
use datatypes::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use datatypes::arrow::record_batch::RecordBatch;
use futures::{Stream, StreamExt, ready};
use greptime_proto::substrait_extension as pb;
use prost::Message;
use snafu::ResultExt;

use crate::error::{ColumnNotFoundSnafu, DataFusionPlanningSnafu, DeserializeSnafu, Result};
use crate::extension_plan::{Millisecond, resolve_column_name, serialize_column_index};
/// `ScalarCalculate` is the custom logical plan that implements
/// [`scalar`](https://prometheus.io/docs/prometheus/latest/querying/functions/#scalar)
/// in PromQL. It returns NaN when the input contains multiple time series.
///
/// When the input contains exactly one time series, that series is returned as the scalar value.
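///
/// For example, `scalar(sum(http_requests_total))` yields the value of the
/// single aggregated series at each step, while `scalar(http_requests_total)`
/// over several labelled series yields NaN at every step.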
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ScalarCalculate {
    start: Millisecond,
    end: Millisecond,
    interval: Millisecond,

    time_index: String,
    tag_columns: Vec<String>,
    field_column: String,
    input: LogicalPlan,
    output_schema: DFSchemaRef,
    unfix: Option<UnfixIndices>,
}

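/// Column indices captured when the plan is serialized, to be resolved back
/// into column names once the real input plan (and its schema) is attached.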
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
struct UnfixIndices {
    pub time_index_idx: u64,
    pub tag_column_indices: Vec<u64>,
    pub field_column_idx: u64,
}

impl ScalarCalculate {
    /// Creates a new `ScalarCalculate` plan.
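    ///
    /// A minimal construction sketch (the column names here are hypothetical;
    /// the input plan must already expose the named time index, tag, and
    /// field columns):
    ///
    /// ```ignore
    /// let plan = ScalarCalculate::new(
    ///     0,       // start, in milliseconds
    ///     15_000,  // end, in milliseconds
    ///     5_000,   // step interval, in milliseconds
    ///     input,   // any `LogicalPlan` containing the columns below
    ///     "ts",
    ///     &["host".to_string()],
    ///     "val",
    ///     None,    // optional table name used to qualify the output schema
    /// )?;
    /// ```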
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        start: Millisecond,
        end: Millisecond,
        interval: Millisecond,
        input: LogicalPlan,
        time_index: &str,
        tag_columns: &[String],
        field_column: &str,
        table_name: Option<&str>,
    ) -> Result<Self> {
        let input_schema = input.schema();
        let Ok(ts_field) = input_schema
            .field_with_unqualified_name(time_index)
            .cloned()
        else {
            return ColumnNotFoundSnafu { col: time_index }.fail();
        };
        let val_field = Field::new(format!("scalar({})", field_column), DataType::Float64, true);
        let qualifier = table_name.map(TableReference::bare);
        let schema = DFSchema::new_with_metadata(
            vec![
                (qualifier.clone(), Arc::new(ts_field)),
                (qualifier, Arc::new(val_field)),
            ],
            input_schema.metadata().clone(),
        )
        .context(DataFusionPlanningSnafu)?;

        Ok(Self {
            start,
            end,
            interval,
            time_index: time_index.to_string(),
            tag_columns: tag_columns.to_vec(),
            field_column: field_column.to_string(),
            input,
            output_schema: Arc::new(schema),
            unfix: None,
        })
    }

    /// The name of this custom plan.
    pub const fn name() -> &'static str {
        "ScalarCalculate"
    }

    /// Creates a new execution plan from this `ScalarCalculate`.
    pub fn to_execution_plan(
        &self,
        exec_input: Arc<dyn ExecutionPlan>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let fields: Vec<_> = self
            .output_schema
            .fields()
            .iter()
            .map(|field| Field::new(field.name(), field.data_type().clone(), field.is_nullable()))
            .collect();
        let input_schema = exec_input.schema();
        let ts_index = input_schema
            .index_of(&self.time_index)
            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
        let val_index = input_schema
            .index_of(&self.field_column)
            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
        let schema = Arc::new(Schema::new(fields));
        let properties = exec_input.properties();
        let properties = PlanProperties::new(
            EquivalenceProperties::new(schema.clone()),
            Partitioning::UnknownPartitioning(1),
            properties.emission_type,
            properties.boundedness,
        );
        Ok(Arc::new(ScalarCalculateExec {
            start: self.start,
            end: self.end,
            interval: self.interval,
            schema,
            input: exec_input,
            project_index: (ts_index, val_index),
            tag_columns: self.tag_columns.clone(),
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        }))
    }

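    /// Serializes this plan for the Substrait extension, recording the time
    /// index, tag, and field columns as indices into the input schema so that
    /// the plan can be reconstructed on another node.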
    pub fn serialize(&self) -> Vec<u8> {
        let time_index_idx = serialize_column_index(self.input.schema(), &self.time_index);

        let tag_column_indices = self
            .tag_columns
            .iter()
            .map(|name| serialize_column_index(self.input.schema(), name))
            .collect::<Vec<u64>>();

        let field_column_idx = serialize_column_index(self.input.schema(), &self.field_column);

        pb::ScalarCalculate {
            start: self.start,
            end: self.end,
            interval: self.interval,
            time_index_idx,
            tag_column_indices,
            field_column_idx,
            ..Default::default()
        }
        .encode_to_vec()
    }

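    /// Deserializes a plan produced by [`Self::serialize`]. The decoded plan
    /// holds placeholder names plus the original column indices in `unfix`;
    /// the real names are resolved against the actual input schema in
    /// `with_exprs_and_inputs`.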
    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
        let pb_scalar_calculate = pb::ScalarCalculate::decode(bytes).context(DeserializeSnafu)?;
        let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(DFSchema::empty()),
        });

        let unfix = UnfixIndices {
            time_index_idx: pb_scalar_calculate.time_index_idx,
            tag_column_indices: pb_scalar_calculate.tag_column_indices.clone(),
            field_column_idx: pb_scalar_calculate.field_column_idx,
        };

        // TODO(Taylor-lagrange): support timestamps of different precisions
        let ts_field = Field::new(
            "placeholder_time_index",
            DataType::Timestamp(TimeUnit::Millisecond, None),
            true,
        );
        let val_field = Field::new("placeholder_field", DataType::Float64, true);
        // TODO(Taylor-lagrange): missing table name in pb
        let schema = DFSchema::new_with_metadata(
            vec![(None, Arc::new(ts_field)), (None, Arc::new(val_field))],
            HashMap::new(),
        )
        .context(DataFusionPlanningSnafu)?;

        Ok(Self {
            start: pb_scalar_calculate.start,
            end: pb_scalar_calculate.end,
            interval: pb_scalar_calculate.interval,
            time_index: String::new(),
            tag_columns: Vec::new(),
            field_column: String::new(),
            output_schema: Arc::new(schema),
            input: placeholder_plan,
            unfix: Some(unfix),
        })
    }
}

impl PartialOrd for ScalarCalculate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Compare fields in order, excluding `output_schema` and `unfix`.
        match self.start.partial_cmp(&other.start) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.end.partial_cmp(&other.end) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.interval.partial_cmp(&other.interval) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.time_index.partial_cmp(&other.time_index) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.tag_columns.partial_cmp(&other.tag_columns) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.field_column.partial_cmp(&other.field_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        self.input.partial_cmp(&other.input)
    }
}

impl UserDefinedLogicalNodeCore for ScalarCalculate {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.input]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.output_schema
    }

    fn expressions(&self) -> Vec<Expr> {
        if self.unfix.is_some() {
            return vec![];
        }

        self.tag_columns
            .iter()
            .map(col)
            .chain(std::iter::once(col(&self.time_index)))
            .chain(std::iter::once(col(&self.field_column)))
            .collect()
    }

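    /// Reports which input columns must survive projection pushdown: the time
    /// index, the field column, and every tag column (tags are still needed
    /// at execution time to detect whether more than one series is present).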
    fn necessary_children_exprs(&self, _output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
        if self.unfix.is_some() {
            return None;
        }

        let input_schema = self.input.schema();
        let time_index_idx = input_schema.index_of_column_by_name(None, &self.time_index)?;
        let field_column_idx = input_schema.index_of_column_by_name(None, &self.field_column)?;

        let mut required = Vec::with_capacity(2 + self.tag_columns.len());
        required.extend([time_index_idx, field_column_idx]);
        for tag in &self.tag_columns {
            required.push(input_schema.index_of_column_by_name(None, tag)?);
        }

        required.sort_unstable();
        required.dedup();
        Some(vec![required])
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "ScalarCalculate: tags={:?}", self.tag_columns)
    }

    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        let input: LogicalPlan = inputs.into_iter().next().unwrap();
        let input_schema = input.schema();

        if let Some(unfix) = &self.unfix {
            // Resolve the serialized column indices back into names using the
            // actual input schema.
            let time_index = resolve_column_name(
                unfix.time_index_idx,
                input_schema,
                "ScalarCalculate",
                "time index",
            )?;

            let tag_columns = unfix
                .tag_column_indices
                .iter()
                .map(|idx| resolve_column_name(*idx, input_schema, "ScalarCalculate", "tag"))
                .collect::<DataFusionResult<Vec<String>>>()?;

            let field_column = resolve_column_name(
                unfix.field_column_idx,
                input_schema,
                "ScalarCalculate",
                "field",
            )?;

            // Recreate the output schema with the actual column names.
            let ts_field = Field::new(
                &time_index,
                DataType::Timestamp(TimeUnit::Millisecond, None),
                true,
            );
            let val_field =
                Field::new(format!("scalar({})", field_column), DataType::Float64, true);
            let schema = DFSchema::new_with_metadata(
                vec![(None, Arc::new(ts_field)), (None, Arc::new(val_field))],
                HashMap::new(),
            )
            .context(DataFusionPlanningSnafu)?;

            Ok(ScalarCalculate {
                start: self.start,
                end: self.end,
                interval: self.interval,
                time_index,
                tag_columns,
                field_column,
                input,
                output_schema: Arc::new(schema),
                unfix: None,
            })
        } else {
            Ok(ScalarCalculate {
                start: self.start,
                end: self.end,
                interval: self.interval,
                time_index: self.time_index.clone(),
                tag_columns: self.tag_columns.clone(),
                field_column: self.field_column.clone(),
                input,
                output_schema: self.output_schema.clone(),
                unfix: None,
            })
        }
    }
}

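/// Physical counterpart of [`ScalarCalculate`]: concatenates the single input
/// series into the output columns, or emits a NaN-filled series when zero or
/// multiple series are observed.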
#[derive(Debug, Clone)]
struct ScalarCalculateExec {
    start: Millisecond,
    end: Millisecond,
    interval: Millisecond,
    schema: SchemaRef,
    /// With format `(ts_index, field_index)` into the input schema.
    project_index: (usize, usize),
    input: Arc<dyn ExecutionPlan>,
    tag_columns: Vec<String>,
    metric: ExecutionPlanMetricsSet,
    properties: PlanProperties,
}

impl ExecutionPlan for ScalarCalculateExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true; self.children().len()]
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::SinglePartition]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(ScalarCalculateExec {
            start: self.start,
            end: self.end,
            interval: self.interval,
            schema: self.schema.clone(),
            project_index: self.project_index,
            tag_columns: self.tag_columns.clone(),
            input: children[0].clone(),
            metric: self.metric.clone(),
            properties: self.properties.clone(),
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
        let input = self.input.execute(partition, context)?;
        let schema = input.schema();
        let tag_indices = self
            .tag_columns
            .iter()
            .map(|tag| {
                schema
                    .column_with_name(tag)
                    .unwrap_or_else(|| panic!("tag column not found {tag}"))
                    .0
            })
            .collect();

        Ok(Box::pin(ScalarCalculateStream {
            start: self.start,
            end: self.end,
            interval: self.interval,
            schema: self.schema.clone(),
            project_index: self.project_index,
            metric: baseline_metric,
            tag_indices,
            input,
            have_multi_series: false,
            done: false,
            batch: None,
            tag_value: None,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn partition_statistics(&self, partition: Option<usize>) -> DataFusionResult<Statistics> {
        let input_stats = self.input.partition_statistics(partition)?;

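        // One output row per step is expected over `[start, end]`. The range
        // is inclusive, so this can undercount by one row, which is fine for
        // an `Inexact` estimate: e.g. start=0, end=15_000, interval=5_000
        // estimates 3 rows while the NaN path emits 4.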
        let estimated_row_num = (self.end - self.start) as f64 / self.interval as f64;
        let estimated_total_bytes = input_stats
            .total_byte_size
            .get_value()
            .zip(input_stats.num_rows.get_value())
            .map(|(size, rows)| {
                Precision::Inexact(((*size as f64 / *rows as f64) * estimated_row_num).floor() as _)
            })
            .unwrap_or_default();

        Ok(Statistics {
            num_rows: Precision::Inexact(estimated_row_num as _),
            total_byte_size: estimated_total_bytes,
            // TODO(ruihang): support column statistics
            column_statistics: Statistics::unknown_column(&self.schema()),
        })
    }

    fn name(&self) -> &str {
        "ScalarCalculateExec"
    }
}

impl DisplayAs for ScalarCalculateExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default
            | DisplayFormatType::Verbose
            | DisplayFormatType::TreeRender => {
                write!(f, "ScalarCalculateExec: tags={:?}", self.tag_columns)
            }
        }
    }
}

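/// Stream that buffers the single input series and, once the input is
/// exhausted, emits either the concatenated batch or a NaN-filled series.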
struct ScalarCalculateStream {
    start: Millisecond,
    end: Millisecond,
    interval: Millisecond,
    schema: SchemaRef,
    input: SendableRecordBatchStream,
    metric: BaselineMetrics,
    tag_indices: Vec<usize>,
    /// With format `(ts_index, field_index)`.
    project_index: (usize, usize),
    have_multi_series: bool,
    done: bool,
    batch: Option<RecordBatch>,
    tag_value: Option<Vec<String>>,
}

impl RecordBatchStream for ScalarCalculateStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}

impl ScalarCalculateStream {
    fn update_batch(&mut self, batch: RecordBatch) -> DataFusionResult<()> {
        let _timer = self.metric.elapsed_compute();
        // Once multiple time series have been observed, scalar() evaluates to
        // NaN, so further batches are ignored; empty batches carry no data.
        if self.have_multi_series || batch.num_rows() == 0 {
            return Ok(());
        }
        // Fast path: no tag columns means all data belongs to the same series.
        if self.tag_indices.is_empty() {
            self.append_batch(batch)?;
            return Ok(());
        }
        // Returns true when every value in `array` equals `val`, or, when
        // `val` is `None`, when all values in `array` are identical.
        let all_same = |val: Option<&str>, array: &StringArray| -> bool {
            if let Some(v) = val {
                array.iter().all(|s| s == Some(v))
            } else {
                array.is_empty() || array.iter().skip(1).all(|s| s == Some(array.value(0)))
            }
        };
        // Check that the entire batch belongs to the same series.
        let all_tag_columns_same = if let Some(tags) = &self.tag_value {
            tags.iter()
                .zip(self.tag_indices.iter())
                .all(|(value, index)| {
                    let array = batch.column(*index);
                    let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
                    all_same(Some(value), string_array)
                })
        } else {
            let mut tag_values = Vec::with_capacity(self.tag_indices.len());
            let is_same = self.tag_indices.iter().all(|index| {
                let array = batch.column(*index);
                let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
                tag_values.push(string_array.value(0).to_string());
                all_same(None, string_array)
            });
            self.tag_value = Some(tag_values);
            is_same
        };
        if all_tag_columns_same {
            self.append_batch(batch)?;
        } else {
            self.have_multi_series = true;
        }
        Ok(())
    }

    fn append_batch(&mut self, input_batch: RecordBatch) -> DataFusionResult<()> {
        let ts_column = input_batch.column(self.project_index.0).clone();
        let val_column = cast_with_options(
            input_batch.column(self.project_index.1),
            &DataType::Float64,
            &CastOptions::default(),
        )?;
        let input_batch = RecordBatch::try_new(self.schema.clone(), vec![ts_column, val_column])?;
        if let Some(batch) = &self.batch {
            self.batch = Some(concat_batches(&self.schema, vec![batch, &input_batch])?);
        } else {
            self.batch = Some(input_batch);
        }
        Ok(())
    }
}

impl Stream for ScalarCalculateStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            if self.done {
                return Poll::Ready(None);
            }
            match ready!(self.input.poll_next_unpin(cx)) {
                Some(Ok(batch)) => {
                    self.update_batch(batch)?;
                }
                // The inner stream had an error; return it to the caller.
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                // The inner stream is done; produce the output.
                None => {
                    self.done = true;
                    return match self.batch.take() {
                        Some(batch) if !self.have_multi_series => {
                            self.metric.record_output(batch.num_rows());
                            Poll::Ready(Some(Ok(batch)))
                        }
                        _ => {
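                            // Zero or multiple series: emit one NaN per step
                            // over the inclusive range. E.g. start=0,
                            // end=15_000, interval=5_000 yields timestamps
                            // [0, 5_000, 10_000, 15_000].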
                            let time_array = (self.start..=self.end)
                                .step_by(self.interval as _)
                                .collect::<Vec<_>>();
                            let nums = time_array.len();
                            let nan_batch = RecordBatch::try_new(
                                self.schema.clone(),
                                vec![
                                    Arc::new(TimestampMillisecondArray::from(time_array)),
                                    Arc::new(Float64Array::from(vec![f64::NAN; nums])),
                                ],
                            )?;
                            self.metric.record_output(nan_batch.num_rows());
                            Poll::Ready(Some(Ok(nan_batch)))
                        }
                    };
                }
            };
        }
    }
}

#[cfg(test)]
mod test {
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
    use datafusion::prelude::SessionContext;
    use datatypes::arrow::array::{Float64Array, TimestampMillisecondArray};
    use datatypes::arrow::datatypes::TimeUnit;

    use super::*;

    fn project_batch(batch: &RecordBatch, indices: &[usize]) -> RecordBatch {
        let fields = indices
            .iter()
            .map(|&idx| batch.schema().field(idx).clone())
            .collect::<Vec<_>>();
        let columns = indices
            .iter()
            .map(|&idx| batch.column(idx).clone())
            .collect::<Vec<_>>();
        let schema = Arc::new(Schema::new(fields));
        RecordBatch::try_new(schema, columns).unwrap()
    }

    #[test]
    fn necessary_children_exprs_preserve_tag_columns() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
            Field::new("extra", DataType::Utf8, true),
        ]));
        let schema = Arc::new(DFSchema::try_from(schema).unwrap());
        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema,
        });
        let tag_columns = vec!["tag1".to_string(), "tag2".to_string()];
        let plan = ScalarCalculate::new(0, 1, 1, input, "ts", &tag_columns, "val", None).unwrap();

        let required = plan.necessary_children_exprs(&[0, 1]).unwrap();
        assert_eq!(required, vec![vec![0, 1, 2, 3]]);
    }

    #[tokio::test]
    async fn pruning_should_keep_tag_columns_for_exec() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
            Field::new("extra", DataType::Utf8, true),
        ]));
        let df_schema = Arc::new(DFSchema::try_from(schema.clone()).unwrap());
        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: df_schema,
        });
        let tag_columns = vec!["tag1".to_string(), "tag2".to_string()];
        let plan =
            ScalarCalculate::new(0, 15_000, 5000, input, "ts", &tag_columns, "val", None).unwrap();

        let required = plan.necessary_children_exprs(&[0, 1]).unwrap();
        let required = &required[0];

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    0, 5_000, 10_000, 15_000,
                ])),
                Arc::new(StringArray::from(vec!["foo", "foo", "foo", "foo"])),
                Arc::new(StringArray::from(vec!["bar", "bar", "bar", "bar"])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
                Arc::new(StringArray::from(vec!["x", "x", "x", "x"])),
            ],
        )
        .unwrap();

        let projected_batch = project_batch(&batch, required);
        let projected_schema = projected_batch.schema();
        let memory_exec = Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![projected_batch]], projected_schema, None).unwrap(),
        )));
        let scalar_exec = plan.to_execution_plan(memory_exec).unwrap();

        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(scalar_exec, session_context.task_ctx())
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        let batch = &result[0];
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 4);
        assert_eq!(batch.schema().field(0).name(), "ts");
        assert_eq!(batch.schema().field(1).name(), "scalar(val)");

        let ts = batch
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .unwrap();
        assert_eq!(ts.values(), &[0i64, 5_000, 10_000, 15_000]);

        let values = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(values.values(), &[1.0f64, 2.0, 3.0, 4.0]);
    }

    fn prepare_test_data(series: Vec<RecordBatch>) -> DataSourceExec {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[series], schema, None).unwrap(),
        ))
    }

    async fn run_test(series: Vec<RecordBatch>, expected: &str) {
        let memory_exec = Arc::new(prepare_test_data(series));
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("val", DataType::Float64, true),
        ]));
        let properties = PlanProperties::new(
            EquivalenceProperties::new(schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        let scalar_exec = Arc::new(ScalarCalculateExec {
            start: 0,
            end: 15_000,
            interval: 5000,
            tag_columns: vec!["tag1".to_string(), "tag2".to_string()],
            input: memory_exec,
            schema,
            project_index: (0, 3),
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        });
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(scalar_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();
        assert_eq!(result_literal, expected);
    }

    #[tokio::test]
    async fn same_series() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        run_test(
            vec![
                RecordBatch::try_new(
                    schema.clone(),
                    vec![
                        Arc::new(TimestampMillisecondArray::from(vec![0, 5_000])),
                        Arc::new(StringArray::from(vec!["foo", "foo"])),
                        Arc::new(StringArray::from(vec!["🥺", "🥺"])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(TimestampMillisecondArray::from(vec![10_000, 15_000])),
                        Arc::new(StringArray::from(vec!["foo", "foo"])),
                        Arc::new(StringArray::from(vec!["🥺", "🥺"])),
                        Arc::new(Float64Array::from(vec![3.0, 4.0])),
                    ],
                )
                .unwrap(),
            ],
            "+---------------------+-----+\
            \n| ts                  | val |\
            \n+---------------------+-----+\
            \n| 1970-01-01T00:00:00 | 1.0 |\
            \n| 1970-01-01T00:00:05 | 2.0 |\
            \n| 1970-01-01T00:00:10 | 3.0 |\
            \n| 1970-01-01T00:00:15 | 4.0 |\
            \n+---------------------+-----+",
        )
        .await
    }

    #[tokio::test]
    async fn diff_series() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        run_test(
            vec![
                RecordBatch::try_new(
                    schema.clone(),
                    vec![
                        Arc::new(TimestampMillisecondArray::from(vec![0, 5_000])),
                        Arc::new(StringArray::from(vec!["foo", "foo"])),
                        Arc::new(StringArray::from(vec!["🥺", "🥺"])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(TimestampMillisecondArray::from(vec![10_000, 15_000])),
                        Arc::new(StringArray::from(vec!["foo", "foo"])),
                        Arc::new(StringArray::from(vec!["🥺", "😝"])),
                        Arc::new(Float64Array::from(vec![3.0, 4.0])),
                    ],
                )
                .unwrap(),
            ],
            "+---------------------+-----+\
            \n| ts                  | val |\
            \n+---------------------+-----+\
            \n| 1970-01-01T00:00:00 | NaN |\
            \n| 1970-01-01T00:00:05 | NaN |\
            \n| 1970-01-01T00:00:10 | NaN |\
            \n| 1970-01-01T00:00:15 | NaN |\
            \n+---------------------+-----+",
        )
        .await
    }

    #[tokio::test]
    async fn empty_series() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        run_test(
            vec![
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(TimestampMillisecondArray::new_null(0)),
                        Arc::new(StringArray::new_null(0)),
                        Arc::new(StringArray::new_null(0)),
                        Arc::new(Float64Array::new_null(0)),
                    ],
                )
                .unwrap(),
            ],
            "+---------------------+-----+\
            \n| ts                  | val |\
            \n+---------------------+-----+\
            \n| 1970-01-01T00:00:00 | NaN |\
            \n| 1970-01-01T00:00:05 | NaN |\
            \n| 1970-01-01T00:00:10 | NaN |\
            \n| 1970-01-01T00:00:15 | NaN |\
            \n+---------------------+-----+",
        )
        .await
    }
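
    #[test]
    fn serialize_deserialize_roundtrip() {
        // A minimal round-trip sketch: serialization stores column indices,
        // and deserialization yields placeholder names plus `unfix` indices
        // that `with_exprs_and_inputs` later resolves against the real input.
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        let schema = Arc::new(DFSchema::try_from(schema).unwrap());
        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema,
        });
        let tag_columns = vec!["tag1".to_string()];
        let plan =
            ScalarCalculate::new(0, 10_000, 5_000, input, "ts", &tag_columns, "val", None).unwrap();

        let bytes = plan.serialize();
        let decoded = ScalarCalculate::deserialize(&bytes).unwrap();

        assert_eq!(decoded.start, 0);
        assert_eq!(decoded.end, 10_000);
        assert_eq!(decoded.interval, 5_000);
        // Names stay placeholders until the plan is re-attached to its input.
        assert!(decoded.unfix.is_some());
        assert!(decoded.time_index.is_empty());
    }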
}