use std::any::Any;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use datafusion::common::stats::Precision;
use datafusion::common::{DFSchema, DFSchemaRef, Result as DataFusionResult, Statistics};
use datafusion::error::DataFusionError;
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::{EmptyRelation, LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PlanProperties,
    RecordBatchStream, SendableRecordBatchStream,
};
use datafusion::prelude::Expr;
use datafusion::sql::TableReference;
use datatypes::arrow::array::{Array, Float64Array, StringArray, TimestampMillisecondArray};
use datatypes::arrow::compute::{cast_with_options, concat_batches, CastOptions};
use datatypes::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use datatypes::arrow::record_batch::RecordBatch;
use futures::{ready, Stream, StreamExt};
use greptime_proto::substrait_extension as pb;
use prost::Message;
use snafu::ResultExt;

use crate::error::{ColumnNotFoundSnafu, DataFusionPlanningSnafu, DeserializeSnafu, Result};
use crate::extension_plan::Millisecond;

/// `ScalarCalculate` implements the PromQL `scalar()` function: when the
/// input contains exactly one time series, its samples are emitted as the
/// scalar values; otherwise every output point is `NaN`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ScalarCalculate {
    start: Millisecond,
    end: Millisecond,
    interval: Millisecond,

    time_index: String,
    tag_columns: Vec<String>,
    field_column: String,
    input: LogicalPlan,
    output_schema: DFSchemaRef,
}

impl ScalarCalculate {
    /// Create a new `ScalarCalculate` plan on top of `input`.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        start: Millisecond,
        end: Millisecond,
        interval: Millisecond,
        input: LogicalPlan,
        time_index: &str,
        tag_columns: &[String],
        field_column: &str,
        table_name: Option<&str>,
    ) -> Result<Self> {
        let input_schema = input.schema();
        let Ok(ts_field) = input_schema
            .field_with_unqualified_name(time_index)
            .cloned()
        else {
            return ColumnNotFoundSnafu { col: time_index }.fail();
        };
        let val_field = Field::new(format!("scalar({})", field_column), DataType::Float64, true);
        let qualifier = table_name.map(TableReference::bare);
        let schema = DFSchema::new_with_metadata(
            vec![
                (qualifier.clone(), Arc::new(ts_field)),
                (qualifier, Arc::new(val_field)),
            ],
            input_schema.metadata().clone(),
        )
        .context(DataFusionPlanningSnafu)?;

        Ok(Self {
            start,
            end,
            interval,
            time_index: time_index.to_string(),
            tag_columns: tag_columns.to_vec(),
            field_column: field_column.to_string(),
            input,
            output_schema: Arc::new(schema),
        })
    }

    /// The name of this logical plan node.
    pub const fn name() -> &'static str {
        "ScalarCalculate"
    }

    /// Convert this logical node into a [`ScalarCalculateExec`] reading from
    /// `exec_input`.
    pub fn to_execution_plan(
        &self,
        exec_input: Arc<dyn ExecutionPlan>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let fields: Vec<_> = self
            .output_schema
            .fields()
            .iter()
            .map(|field| Field::new(field.name(), field.data_type().clone(), field.is_nullable()))
            .collect();
        let input_schema = exec_input.schema();
        let ts_index = input_schema
            .index_of(&self.time_index)
            .map_err(|e| DataFusionError::ArrowError(e, None))?;
        let val_index = input_schema
            .index_of(&self.field_column)
            .map_err(|e| DataFusionError::ArrowError(e, None))?;
        let schema = Arc::new(Schema::new(fields));
        let properties = exec_input.properties();
        let properties = PlanProperties::new(
            EquivalenceProperties::new(schema.clone()),
            Partitioning::UnknownPartitioning(1),
            properties.emission_type,
            properties.boundedness,
        );
        Ok(Arc::new(ScalarCalculateExec {
            start: self.start,
            end: self.end,
            interval: self.interval,
            schema,
            input: exec_input,
            project_index: (ts_index, val_index),
            tag_columns: self.tag_columns.clone(),
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        }))
    }

    /// Encode this plan's scalar fields into the protobuf extension message.
    pub fn serialize(&self) -> Vec<u8> {
        pb::ScalarCalculate {
            start: self.start,
            end: self.end,
            interval: self.interval,
            time_index: self.time_index.clone(),
            tag_columns: self.tag_columns.clone(),
            field_column: self.field_column.clone(),
        }
        .encode_to_vec()
    }

    /// Decode a `ScalarCalculate` from protobuf bytes. The input plan is left
    /// as an empty placeholder and has to be supplied afterwards.
    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
        let pb_scalar_calculate = pb::ScalarCalculate::decode(bytes).context(DeserializeSnafu)?;
        let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(DFSchema::empty()),
        });
        // the serialized message doesn't carry the time index's data type,
        // so a millisecond timestamp is assumed here
        let ts_field = Field::new(
            &pb_scalar_calculate.time_index,
            DataType::Timestamp(TimeUnit::Millisecond, None),
            true,
        );
        let val_field = Field::new(
            format!("scalar({})", pb_scalar_calculate.field_column),
            DataType::Float64,
            true,
        );
        let schema = DFSchema::new_with_metadata(
            vec![(None, Arc::new(ts_field)), (None, Arc::new(val_field))],
            HashMap::new(),
        )
        .context(DataFusionPlanningSnafu)?;

        Ok(Self {
            start: pb_scalar_calculate.start,
            end: pb_scalar_calculate.end,
            interval: pb_scalar_calculate.interval,
            time_index: pb_scalar_calculate.time_index,
            tag_columns: pb_scalar_calculate.tag_columns,
            field_column: pb_scalar_calculate.field_column,
            output_schema: Arc::new(schema),
            input: placeholder_plan,
        })
    }
}

impl PartialOrd for ScalarCalculate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // compare fields in order, excluding the derived output_schema
        match self.start.partial_cmp(&other.start) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.end.partial_cmp(&other.end) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.interval.partial_cmp(&other.interval) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.time_index.partial_cmp(&other.time_index) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.tag_columns.partial_cmp(&other.tag_columns) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.field_column.partial_cmp(&other.field_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        self.input.partial_cmp(&other.input)
    }
}

impl UserDefinedLogicalNodeCore for ScalarCalculate {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.input]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.output_schema
    }

    fn expressions(&self) -> Vec<Expr> {
        vec![]
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "ScalarCalculate: tags={:?}", self.tag_columns)
    }

    fn with_exprs_and_inputs(
        &self,
        exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        if !exprs.is_empty() {
            return Err(DataFusionError::Internal(
                "ScalarCalculate should not have any expressions".to_string(),
            ));
        }
        Ok(ScalarCalculate {
            start: self.start,
            end: self.end,
            interval: self.interval,
            time_index: self.time_index.clone(),
            tag_columns: self.tag_columns.clone(),
            field_column: self.field_column.clone(),
            input: inputs.into_iter().next().unwrap(),
            output_schema: self.output_schema.clone(),
        })
    }
}

#[derive(Debug, Clone)]
struct ScalarCalculateExec {
    start: Millisecond,
    end: Millisecond,
    interval: Millisecond,
    schema: SchemaRef,
    project_index: (usize, usize),
    input: Arc<dyn ExecutionPlan>,
    tag_columns: Vec<String>,
    metric: ExecutionPlanMetricsSet,
    properties: PlanProperties,
}

impl ExecutionPlan for ScalarCalculateExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true; self.children().len()]
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::SinglePartition]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(ScalarCalculateExec {
            start: self.start,
            end: self.end,
            interval: self.interval,
            schema: self.schema.clone(),
            project_index: self.project_index,
            tag_columns: self.tag_columns.clone(),
            input: children[0].clone(),
            metric: self.metric.clone(),
            properties: self.properties.clone(),
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
        let input = self.input.execute(partition, context)?;
        let schema = input.schema();
        let tag_indices = self
            .tag_columns
            .iter()
            .map(|tag| {
                schema
                    .column_with_name(tag)
                    .unwrap_or_else(|| panic!("tag column {tag} not found"))
                    .0
            })
            .collect();

        Ok(Box::pin(ScalarCalculateStream {
            start: self.start,
            end: self.end,
            interval: self.interval,
            schema: self.schema.clone(),
            project_index: self.project_index,
            metric: baseline_metric,
            tag_indices,
            input,
            have_multi_series: false,
            done: false,
            batch: None,
            tag_value: None,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn statistics(&self) -> DataFusionResult<Statistics> {
        let input_stats = self.input.statistics()?;

        // rough estimate: about one output row per `interval` step within
        // `[start, end)`, scaled by the input's average row width
        let estimated_row_num = (self.end - self.start) as f64 / self.interval as f64;
        let estimated_total_bytes = input_stats
            .total_byte_size
            .get_value()
            .zip(input_stats.num_rows.get_value())
            .map(|(size, rows)| {
                Precision::Inexact(((*size as f64 / *rows as f64) * estimated_row_num).floor() as _)
            })
            .unwrap_or_default();

        Ok(Statistics {
            num_rows: Precision::Inexact(estimated_row_num as _),
            total_byte_size: estimated_total_bytes,
            column_statistics: Statistics::unknown_column(&self.schema()),
        })
    }

    fn name(&self) -> &str {
        "ScalarCalculateExec"
    }
}

impl DisplayAs for ScalarCalculateExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "ScalarCalculateExec: tags={:?}", self.tag_columns)
            }
        }
    }
}

struct ScalarCalculateStream {
    start: Millisecond,
    end: Millisecond,
    interval: Millisecond,
    schema: SchemaRef,
    input: SendableRecordBatchStream,
    metric: BaselineMetrics,
    tag_indices: Vec<usize>,
    project_index: (usize, usize),
    /// set once a second time series shows up; the output then degrades to an
    /// all-NaN series
    have_multi_series: bool,
    done: bool,
    /// rows accumulated from the (so far) single input series
    batch: Option<RecordBatch>,
    /// tag values observed in the first non-empty batch
    tag_value: Option<Vec<String>>,
}

impl RecordBatchStream for ScalarCalculateStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}

impl ScalarCalculateStream {
    fn update_batch(&mut self, batch: RecordBatch) -> DataFusionResult<()> {
        let _timer = self.metric.elapsed_compute();
        // nothing more to accumulate once multiple series are found, and
        // empty batches can be skipped entirely
        if self.have_multi_series || batch.num_rows() == 0 {
            return Ok(());
        }
        // fast path: no tag columns means all rows belong to the same series
        if self.tag_indices.is_empty() {
            self.append_batch(batch)?;
            return Ok(());
        }
        // check whether every value in `array` equals `val`, or equals the
        // array's own first value when no previous value is known
        let all_same = |val: Option<&str>, array: &StringArray| -> bool {
            if let Some(v) = val {
                array.iter().all(|s| s == Some(v))
            } else {
                array.is_empty() || array.iter().skip(1).all(|s| s == Some(array.value(0)))
            }
        };
        let all_tag_columns_same = if let Some(tags) = &self.tag_value {
            tags.iter()
                .zip(self.tag_indices.iter())
                .all(|(value, index)| {
                    let array = batch.column(*index);
                    let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
                    all_same(Some(value), string_array)
                })
        } else {
            let mut tag_values = Vec::with_capacity(self.tag_indices.len());
            let is_same = self.tag_indices.iter().all(|index| {
                let array = batch.column(*index);
                let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
                tag_values.push(string_array.value(0).to_string());
                all_same(None, string_array)
            });
            self.tag_value = Some(tag_values);
            is_same
        };
        if all_tag_columns_same {
            self.append_batch(batch)?;
        } else {
            self.have_multi_series = true;
        }
        Ok(())
    }

    fn append_batch(&mut self, input_batch: RecordBatch) -> DataFusionResult<()> {
        let ts_column = input_batch.column(self.project_index.0).clone();
        let val_column = cast_with_options(
            input_batch.column(self.project_index.1),
            &DataType::Float64,
            &CastOptions::default(),
        )?;
        let input_batch = RecordBatch::try_new(self.schema.clone(), vec![ts_column, val_column])?;
        if let Some(batch) = &self.batch {
            self.batch = Some(concat_batches(&self.schema, vec![batch, &input_batch])?);
        } else {
            self.batch = Some(input_batch);
        }
        Ok(())
    }
}

impl Stream for ScalarCalculateStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            if self.done {
                return Poll::Ready(None);
            }
            match ready!(self.input.poll_next_unpin(cx)) {
                Some(Ok(batch)) => {
                    self.update_batch(batch)?;
                }
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                // input is exhausted: emit the accumulated single series, or
                // an all-NaN series when zero or multiple series were seen
                None => {
                    self.done = true;
                    return match self.batch.take() {
                        Some(batch) if !self.have_multi_series => {
                            self.metric.record_output(batch.num_rows());
                            Poll::Ready(Some(Ok(batch)))
                        }
                        _ => {
                            let time_array = (self.start..=self.end)
                                .step_by(self.interval as _)
                                .collect::<Vec<_>>();
                            let nums = time_array.len();
                            let nan_batch = RecordBatch::try_new(
                                self.schema.clone(),
                                vec![
                                    Arc::new(TimestampMillisecondArray::from(time_array)),
                                    Arc::new(Float64Array::from(vec![f64::NAN; nums])),
                                ],
                            )?;
                            self.metric.record_output(nan_batch.num_rows());
                            Poll::Ready(Some(Ok(nan_batch)))
                        }
                    };
                }
            };
        }
    }
}

#[cfg(test)]
mod test {
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
    use datafusion::physical_plan::memory::MemoryExec;
    use datafusion::prelude::SessionContext;
    use datatypes::arrow::array::{Float64Array, TimestampMillisecondArray};
    use datatypes::arrow::datatypes::TimeUnit;

    use super::*;

    fn prepare_test_data(series: Vec<RecordBatch>) -> MemoryExec {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        MemoryExec::try_new(&[series], schema, None).unwrap()
    }

    async fn run_test(series: Vec<RecordBatch>, expected: &str) {
        let memory_exec = Arc::new(prepare_test_data(series));
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("val", DataType::Float64, true),
        ]));
        let properties = PlanProperties::new(
            EquivalenceProperties::new(schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        let scalar_exec = Arc::new(ScalarCalculateExec {
            start: 0,
            end: 15_000,
            interval: 5000,
            tag_columns: vec!["tag1".to_string(), "tag2".to_string()],
            input: memory_exec,
            schema,
            project_index: (0, 3),
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        });
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(scalar_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();
        assert_eq!(result_literal, expected);
    }

    #[tokio::test]
    async fn same_series() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        run_test(
            vec![
                RecordBatch::try_new(
                    schema.clone(),
                    vec![
                        Arc::new(TimestampMillisecondArray::from(vec![0, 5_000])),
                        Arc::new(StringArray::from(vec!["foo", "foo"])),
                        Arc::new(StringArray::from(vec!["🥺", "🥺"])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(TimestampMillisecondArray::from(vec![10_000, 15_000])),
                        Arc::new(StringArray::from(vec!["foo", "foo"])),
                        Arc::new(StringArray::from(vec!["🥺", "🥺"])),
                        Arc::new(Float64Array::from(vec![3.0, 4.0])),
                    ],
                )
                .unwrap(),
            ],
            "+---------------------+-----+\
            \n| ts                  | val |\
            \n+---------------------+-----+\
            \n| 1970-01-01T00:00:00 | 1.0 |\
            \n| 1970-01-01T00:00:05 | 2.0 |\
            \n| 1970-01-01T00:00:10 | 3.0 |\
            \n| 1970-01-01T00:00:15 | 4.0 |\
            \n+---------------------+-----+",
        )
        .await
    }
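
    // Illustrative addition (not in the original suite): tag values that
    // disagree *within* a single batch should also be detected as multiple
    // series and degrade the result to NaN, exercising the first-batch
    // comparison path in `update_batch`.
    #[tokio::test]
    async fn diff_series_in_one_batch() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        run_test(
            vec![RecordBatch::try_new(
                schema,
                vec![
                    Arc::new(TimestampMillisecondArray::from(vec![0, 5_000])),
                    Arc::new(StringArray::from(vec!["foo", "foo"])),
                    Arc::new(StringArray::from(vec!["🥺", "😝"])),
                    Arc::new(Float64Array::from(vec![1.0, 2.0])),
                ],
            )
            .unwrap()],
            "+---------------------+-----+\
            \n| ts                  | val |\
            \n+---------------------+-----+\
            \n| 1970-01-01T00:00:00 | NaN |\
            \n| 1970-01-01T00:00:05 | NaN |\
            \n| 1970-01-01T00:00:10 | NaN |\
            \n| 1970-01-01T00:00:15 | NaN |\
            \n+---------------------+-----+",
        )
        .await
    }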

    #[tokio::test]
    async fn diff_series() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        run_test(
            vec![
                RecordBatch::try_new(
                    schema.clone(),
                    vec![
                        Arc::new(TimestampMillisecondArray::from(vec![0, 5_000])),
                        Arc::new(StringArray::from(vec!["foo", "foo"])),
                        Arc::new(StringArray::from(vec!["🥺", "🥺"])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(TimestampMillisecondArray::from(vec![10_000, 15_000])),
                        Arc::new(StringArray::from(vec!["foo", "foo"])),
                        Arc::new(StringArray::from(vec!["🥺", "😝"])),
                        Arc::new(Float64Array::from(vec![3.0, 4.0])),
                    ],
                )
                .unwrap(),
            ],
            "+---------------------+-----+\
            \n| ts                  | val |\
            \n+---------------------+-----+\
            \n| 1970-01-01T00:00:00 | NaN |\
            \n| 1970-01-01T00:00:05 | NaN |\
            \n| 1970-01-01T00:00:10 | NaN |\
            \n| 1970-01-01T00:00:15 | NaN |\
            \n+---------------------+-----+",
        )
        .await
    }
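
    // Illustrative addition (not in the original suite): with an empty
    // `tag_columns` list the stream takes the fast path in `update_batch` and
    // forwards every row unchanged, even though the input carries mixed tag
    // values. The exec is built by hand since `run_test` hardcodes two tags.
    #[tokio::test]
    async fn no_tag_columns() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![0, 5_000])),
                Arc::new(StringArray::from(vec!["foo", "bar"])),
                Arc::new(StringArray::from(vec!["🥺", "😝"])),
                Arc::new(Float64Array::from(vec![1.0, 2.0])),
            ],
        )
        .unwrap();
        let memory_exec = Arc::new(prepare_test_data(vec![batch]));
        let output_schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("val", DataType::Float64, true),
        ]));
        let properties = PlanProperties::new(
            EquivalenceProperties::new(output_schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        let scalar_exec = Arc::new(ScalarCalculateExec {
            start: 0,
            end: 15_000,
            interval: 5000,
            tag_columns: vec![],
            input: memory_exec,
            schema: output_schema,
            project_index: (0, 3),
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        });
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(scalar_exec, session_context.task_ctx())
            .await
            .unwrap();
        // both rows survive because no tag comparison takes place
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 2);
    }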

    #[tokio::test]
    async fn empty_series() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        run_test(
            vec![RecordBatch::try_new(
                schema,
                vec![
                    Arc::new(TimestampMillisecondArray::new_null(0)),
                    Arc::new(StringArray::new_null(0)),
                    Arc::new(StringArray::new_null(0)),
                    Arc::new(Float64Array::new_null(0)),
                ],
            )
            .unwrap()],
            "+---------------------+-----+\
            \n| ts                  | val |\
            \n+---------------------+-----+\
            \n| 1970-01-01T00:00:00 | NaN |\
            \n| 1970-01-01T00:00:05 | NaN |\
            \n| 1970-01-01T00:00:10 | NaN |\
            \n| 1970-01-01T00:00:15 | NaN |\
            \n+---------------------+-----+",
        )
        .await
    }
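
    // Illustrative addition (not in the original suite): serialization drops
    // the input plan and the output schema, so the protobuf round-trip is
    // checked field by field against the original plan.
    #[test]
    fn serialize_roundtrip() {
        let ts_field = Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true);
        let input_schema =
            DFSchema::new_with_metadata(vec![(None, Arc::new(ts_field))], HashMap::new()).unwrap();
        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(input_schema),
        });
        let plan = ScalarCalculate::new(
            0,
            15_000,
            5000,
            input,
            "ts",
            &["tag1".to_string()],
            "val",
            None,
        )
        .unwrap();
        let round_tripped = ScalarCalculate::deserialize(&plan.serialize()).unwrap();
        assert_eq!(round_tripped.start, plan.start);
        assert_eq!(round_tripped.end, plan.end);
        assert_eq!(round_tripped.interval, plan.interval);
        assert_eq!(round_tripped.time_index, plan.time_index);
        assert_eq!(round_tripped.tag_columns, plan.tag_columns);
        assert_eq!(round_tripped.field_column, plan.field_column);
    }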
}