use std::any::Any;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use datafusion::common::stats::Precision;
use datafusion::common::{DFSchema, DFSchemaRef, Result as DataFusionResult, Statistics};
use datafusion::error::DataFusionError;
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::{EmptyRelation, LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PlanProperties,
    RecordBatchStream, SendableRecordBatchStream,
};
use datafusion::prelude::Expr;
use datafusion::sql::TableReference;
use datatypes::arrow::array::{Array, Float64Array, StringArray, TimestampMillisecondArray};
use datatypes::arrow::compute::{cast_with_options, concat_batches, CastOptions};
use datatypes::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use datatypes::arrow::record_batch::RecordBatch;
use futures::{ready, Stream, StreamExt};
use greptime_proto::substrait_extension as pb;
use prost::Message;
use snafu::ResultExt;

use crate::error::{ColumnNotFoundSnafu, DataFusionPlanningSnafu, DeserializeSnafu, Result};
use crate::extension_plan::Millisecond;

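/// `ScalarCalculate` is the logical plan node behind PromQL's `scalar()`
/// function: when its input contains exactly one time series, it projects that
/// series' time index and value columns; when the input is empty or contains
/// multiple series, it yields `NaN` for every step in `[start, end]`.
///
/// ```ignore
/// // Hypothetical usage sketch; `input`, `tags`, and the column names are
/// // placeholders for whatever the caller's plan provides.
/// let node = ScalarCalculate::new(start, end, interval, input, "ts", &tags, "val", None)?;
/// ```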
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ScalarCalculate {
    start: Millisecond,
    end: Millisecond,
    interval: Millisecond,

    time_index: String,
    tag_columns: Vec<String>,
    field_column: String,
    input: LogicalPlan,
    output_schema: DFSchemaRef,
}

impl ScalarCalculate {
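    /// Create a new `ScalarCalculate` node. Returns a `ColumnNotFound` error if
    /// `time_index` does not exist in the input schema.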
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        start: Millisecond,
        end: Millisecond,
        interval: Millisecond,
        input: LogicalPlan,
        time_index: &str,
        tag_columns: &[String],
        field_column: &str,
        table_name: Option<&str>,
    ) -> Result<Self> {
        let input_schema = input.schema();
        let Ok(ts_field) = input_schema
            .field_with_unqualified_name(time_index)
            .cloned()
        else {
            return ColumnNotFoundSnafu { col: time_index }.fail();
        };
        let val_field = Field::new(format!("scalar({})", field_column), DataType::Float64, true);
        let qualifier = table_name.map(TableReference::bare);
        let schema = DFSchema::new_with_metadata(
            vec![
                (qualifier.clone(), Arc::new(ts_field)),
                (qualifier, Arc::new(val_field)),
            ],
            input_schema.metadata().clone(),
        )
        .context(DataFusionPlanningSnafu)?;

        Ok(Self {
            start,
            end,
            interval,
            time_index: time_index.to_string(),
            tag_columns: tag_columns.to_vec(),
            field_column: field_column.to_string(),
            input,
            output_schema: Arc::new(schema),
        })
    }

    pub const fn name() -> &'static str {
        "ScalarCalculate"
    }

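    /// Convert this logical node into a [`ScalarCalculateExec`], resolving the
    /// time index and field columns to their offsets in the physical input
    /// schema.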
    pub fn to_execution_plan(
        &self,
        exec_input: Arc<dyn ExecutionPlan>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        let fields: Vec<_> = self
            .output_schema
            .fields()
            .iter()
            .map(|field| Field::new(field.name(), field.data_type().clone(), field.is_nullable()))
            .collect();
        let input_schema = exec_input.schema();
        let ts_index = input_schema
            .index_of(&self.time_index)
            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
        let val_index = input_schema
            .index_of(&self.field_column)
            .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
        let schema = Arc::new(Schema::new(fields));
        let properties = exec_input.properties();
        let properties = PlanProperties::new(
            EquivalenceProperties::new(schema.clone()),
            Partitioning::UnknownPartitioning(1),
            properties.emission_type,
            properties.boundedness,
        );
        Ok(Arc::new(ScalarCalculateExec {
            start: self.start,
            end: self.end,
            interval: self.interval,
            schema,
            input: exec_input,
            project_index: (ts_index, val_index),
            tag_columns: self.tag_columns.clone(),
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        }))
    }

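    /// Encode this node as a `pb::ScalarCalculate` substrait extension message.
    /// The input plan and the output schema are not part of the encoding.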
    pub fn serialize(&self) -> Vec<u8> {
        pb::ScalarCalculate {
            start: self.start,
            end: self.end,
            interval: self.interval,
            time_index: self.time_index.clone(),
            tag_columns: self.tag_columns.clone(),
            field_column: self.field_column.clone(),
        }
        .encode_to_vec()
    }

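    /// Decode a `ScalarCalculate` from a protobuf-encoded extension message.
    /// The input is a placeholder `EmptyRelation` that the caller is expected
    /// to replace, and the output schema is rebuilt with an unqualified,
    /// millisecond-precision time index.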
    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
        let pb_scalar_calculate = pb::ScalarCalculate::decode(bytes).context(DeserializeSnafu)?;
        let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(DFSchema::empty()),
        });
        let ts_field = Field::new(
            &pb_scalar_calculate.time_index,
            DataType::Timestamp(TimeUnit::Millisecond, None),
            true,
        );
        let val_field = Field::new(
            format!("scalar({})", pb_scalar_calculate.field_column),
            DataType::Float64,
            true,
        );
        let schema = DFSchema::new_with_metadata(
            vec![(None, Arc::new(ts_field)), (None, Arc::new(val_field))],
            HashMap::new(),
        )
        .context(DataFusionPlanningSnafu)?;

        Ok(Self {
            start: pb_scalar_calculate.start,
            end: pb_scalar_calculate.end,
            interval: pb_scalar_calculate.interval,
            time_index: pb_scalar_calculate.time_index,
            tag_columns: pb_scalar_calculate.tag_columns,
            field_column: pb_scalar_calculate.field_column,
            output_schema: Arc::new(schema),
            input: placeholder_plan,
        })
    }
}

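// Manual `PartialOrd` that compares every field except `output_schema`.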
impl PartialOrd for ScalarCalculate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        match self.start.partial_cmp(&other.start) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.end.partial_cmp(&other.end) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.interval.partial_cmp(&other.interval) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.time_index.partial_cmp(&other.time_index) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.tag_columns.partial_cmp(&other.tag_columns) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        match self.field_column.partial_cmp(&other.field_column) {
            Some(core::cmp::Ordering::Equal) => {}
            ord => return ord,
        }
        self.input.partial_cmp(&other.input)
    }
}

impl UserDefinedLogicalNodeCore for ScalarCalculate {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.input]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.output_schema
    }

    fn expressions(&self) -> Vec<Expr> {
        vec![]
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "ScalarCalculate: tags={:?}", self.tag_columns)
    }

    fn with_exprs_and_inputs(
        &self,
        exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        if !exprs.is_empty() {
            return Err(DataFusionError::Internal(
                "ScalarCalculate should not have any expressions".to_string(),
            ));
        }
        Ok(ScalarCalculate {
            start: self.start,
            end: self.end,
            interval: self.interval,
            time_index: self.time_index.clone(),
            tag_columns: self.tag_columns.clone(),
            field_column: self.field_column.clone(),
            input: inputs.into_iter().next().unwrap(),
            output_schema: self.output_schema.clone(),
        })
    }
}

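/// Physical counterpart of [`ScalarCalculate`]. It requires a single input
/// partition and always emits exactly one output partition.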
#[derive(Debug, Clone)]
struct ScalarCalculateExec {
    start: Millisecond,
    end: Millisecond,
    interval: Millisecond,
    schema: SchemaRef,
    project_index: (usize, usize),
    input: Arc<dyn ExecutionPlan>,
    tag_columns: Vec<String>,
    metric: ExecutionPlanMetricsSet,
    properties: PlanProperties,
}

impl ExecutionPlan for ScalarCalculateExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true; self.children().len()]
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::SinglePartition]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(ScalarCalculateExec {
            start: self.start,
            end: self.end,
            interval: self.interval,
            schema: self.schema.clone(),
            project_index: self.project_index,
            tag_columns: self.tag_columns.clone(),
            input: children[0].clone(),
            metric: self.metric.clone(),
            properties: self.properties.clone(),
        }))
    }

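    /// Resolve each tag column to its index in the input schema (panicking if a
    /// tag column is missing) and wrap the input in a [`ScalarCalculateStream`].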
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
        let input = self.input.execute(partition, context)?;
        let schema = input.schema();
        let tag_indices = self
            .tag_columns
            .iter()
            .map(|tag| {
                schema
                    .column_with_name(tag)
                    .unwrap_or_else(|| panic!("tag column not found {tag}"))
                    .0
            })
            .collect();

        Ok(Box::pin(ScalarCalculateStream {
            start: self.start,
            end: self.end,
            interval: self.interval,
            schema: self.schema.clone(),
            project_index: self.project_index,
            metric: baseline_metric,
            tag_indices,
            input,
            have_multi_series: false,
            done: false,
            batch: None,
            tag_value: None,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

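    /// Estimate output statistics: the row count follows from the query range
    /// and step (`(end - start) / interval`), and the byte size scales the
    /// input's average bytes-per-row by that estimated row count.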
    fn partition_statistics(&self, partition: Option<usize>) -> DataFusionResult<Statistics> {
        let input_stats = self.input.partition_statistics(partition)?;

        let estimated_row_num = (self.end - self.start) as f64 / self.interval as f64;
        let estimated_total_bytes = input_stats
            .total_byte_size
            .get_value()
            .zip(input_stats.num_rows.get_value())
            .map(|(size, rows)| {
                Precision::Inexact(((*size as f64 / *rows as f64) * estimated_row_num).floor() as _)
            })
            .unwrap_or_default();

        Ok(Statistics {
            num_rows: Precision::Inexact(estimated_row_num as _),
            total_byte_size: estimated_total_bytes,
            column_statistics: Statistics::unknown_column(&self.schema()),
        })
    }

    fn name(&self) -> &str {
        "ScalarCalculateExec"
    }
}

impl DisplayAs for ScalarCalculateExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default
            | DisplayFormatType::Verbose
            | DisplayFormatType::TreeRender => {
                write!(f, "ScalarCalculateExec: tags={:?}", self.tag_columns)
            }
        }
    }
}

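/// Stream that drains its input while tracking whether all rows belong to a
/// single time series. `tag_value` remembers the tag values of the first
/// series seen, `batch` accumulates its rows, and `have_multi_series` is set
/// as soon as a different series appears, in which case the output degrades
/// to a `NaN`-filled range.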
struct ScalarCalculateStream {
    start: Millisecond,
    end: Millisecond,
    interval: Millisecond,
    schema: SchemaRef,
    input: SendableRecordBatchStream,
    metric: BaselineMetrics,
    tag_indices: Vec<usize>,
    project_index: (usize, usize),

    have_multi_series: bool,
    done: bool,
    batch: Option<RecordBatch>,
    tag_value: Option<Vec<String>>,
}

impl RecordBatchStream for ScalarCalculateStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}

impl ScalarCalculateStream {
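    /// Fold one input batch into the stream state. The batch is appended only
    /// while every tag column holds a single consistent value across all input
    /// seen so far; the first deviation sets `have_multi_series`, after which
    /// all further input is ignored.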
    fn update_batch(&mut self, batch: RecordBatch) -> DataFusionResult<()> {
        let _timer = self.metric.elapsed_compute().timer();
        // Nothing to do once a second series has been seen, or for empty batches.
        if self.have_multi_series || batch.num_rows() == 0 {
            return Ok(());
        }
        // Fast path: with no tag columns there is only one possible series.
        if self.tag_indices.is_empty() {
            self.append_batch(batch)?;
            return Ok(());
        }
        // Whether every element of `array` equals `val`, or the first element
        // when no reference value is given.
        let all_same = |val: Option<&str>, array: &StringArray| -> bool {
            if let Some(v) = val {
                array.iter().all(|s| s == Some(v))
            } else {
                array.is_empty() || array.iter().skip(1).all(|s| s == Some(array.value(0)))
            }
        };
        let all_tag_columns_same = if let Some(tags) = &self.tag_value {
            tags.iter()
                .zip(self.tag_indices.iter())
                .all(|(value, index)| {
                    let array = batch.column(*index);
                    let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
                    all_same(Some(value), string_array)
                })
        } else {
            // First non-empty batch: record its tag values as the reference series.
            let mut tag_values = Vec::with_capacity(self.tag_indices.len());
            let is_same = self.tag_indices.iter().all(|index| {
                let array = batch.column(*index);
                let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
                tag_values.push(string_array.value(0).to_string());
                all_same(None, string_array)
            });
            self.tag_value = Some(tag_values);
            is_same
        };
        if all_tag_columns_same {
            self.append_batch(batch)?;
        } else {
            self.have_multi_series = true;
        }
        Ok(())
    }

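    /// Project the time index column and the value column (cast to `Float64`)
    /// out of `input_batch` and concatenate the result onto the pending batch.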
    fn append_batch(&mut self, input_batch: RecordBatch) -> DataFusionResult<()> {
        let ts_column = input_batch.column(self.project_index.0).clone();
        let val_column = cast_with_options(
            input_batch.column(self.project_index.1),
            &DataType::Float64,
            &CastOptions::default(),
        )?;
        let input_batch = RecordBatch::try_new(self.schema.clone(), vec![ts_column, val_column])?;
        if let Some(batch) = &self.batch {
            self.batch = Some(concat_batches(&self.schema, vec![batch, &input_batch])?);
        } else {
            self.batch = Some(input_batch);
        }
        Ok(())
    }
}

impl Stream for ScalarCalculateStream {
    type Item = DataFusionResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            if self.done {
                return Poll::Ready(None);
            }
            match ready!(self.input.poll_next_unpin(cx)) {
                Some(Ok(batch)) => {
                    self.update_batch(batch)?;
                }
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                // Input exhausted: emit the accumulated batch for a single
                // series, or a NaN-filled series over `start..=end` otherwise.
                None => {
                    self.done = true;
                    return match self.batch.take() {
                        Some(batch) if !self.have_multi_series => {
                            self.metric.record_output(batch.num_rows());
                            Poll::Ready(Some(Ok(batch)))
                        }
                        _ => {
                            let time_array = (self.start..=self.end)
                                .step_by(self.interval as _)
                                .collect::<Vec<_>>();
                            let nums = time_array.len();
                            let nan_batch = RecordBatch::try_new(
                                self.schema.clone(),
                                vec![
                                    Arc::new(TimestampMillisecondArray::from(time_array)),
                                    Arc::new(Float64Array::from(vec![f64::NAN; nums])),
                                ],
                            )?;
                            self.metric.record_output(nan_batch.num_rows());
                            Poll::Ready(Some(Ok(nan_batch)))
                        }
                    };
                }
            };
        }
    }
}

#[cfg(test)]
mod test {
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
    use datafusion::prelude::SessionContext;
    use datatypes::arrow::array::{Float64Array, TimestampMillisecondArray};
    use datatypes::arrow::datatypes::TimeUnit;

    use super::*;

    fn prepare_test_data(series: Vec<RecordBatch>) -> DataSourceExec {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[series], schema, None).unwrap(),
        ))
    }

    async fn run_test(series: Vec<RecordBatch>, expected: &str) {
        let memory_exec = Arc::new(prepare_test_data(series));
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("val", DataType::Float64, true),
        ]));
        let properties = PlanProperties::new(
            EquivalenceProperties::new(schema.clone()),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Bounded,
        );
        let scalar_exec = Arc::new(ScalarCalculateExec {
            start: 0,
            end: 15_000,
            interval: 5000,
            tag_columns: vec!["tag1".to_string(), "tag2".to_string()],
            input: memory_exec,
            schema,
            project_index: (0, 3),
            metric: ExecutionPlanMetricsSet::new(),
            properties,
        });
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(scalar_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();
        assert_eq!(result_literal, expected);
    }

    #[tokio::test]
    async fn same_series() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        run_test(
            vec![
                RecordBatch::try_new(
                    schema.clone(),
                    vec![
                        Arc::new(TimestampMillisecondArray::from(vec![0, 5_000])),
                        Arc::new(StringArray::from(vec!["foo", "foo"])),
                        Arc::new(StringArray::from(vec!["🥺", "🥺"])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(TimestampMillisecondArray::from(vec![10_000, 15_000])),
                        Arc::new(StringArray::from(vec!["foo", "foo"])),
                        Arc::new(StringArray::from(vec!["🥺", "🥺"])),
                        Arc::new(Float64Array::from(vec![3.0, 4.0])),
                    ],
                )
                .unwrap(),
            ],
            "+---------------------+-----+\
            \n| ts                  | val |\
            \n+---------------------+-----+\
            \n| 1970-01-01T00:00:00 | 1.0 |\
            \n| 1970-01-01T00:00:05 | 2.0 |\
            \n| 1970-01-01T00:00:10 | 3.0 |\
            \n| 1970-01-01T00:00:15 | 4.0 |\
            \n+---------------------+-----+",
        )
        .await
    }

    #[tokio::test]
    async fn diff_series() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        run_test(
            vec![
                RecordBatch::try_new(
                    schema.clone(),
                    vec![
                        Arc::new(TimestampMillisecondArray::from(vec![0, 5_000])),
                        Arc::new(StringArray::from(vec!["foo", "foo"])),
                        Arc::new(StringArray::from(vec!["🥺", "🥺"])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(TimestampMillisecondArray::from(vec![10_000, 15_000])),
                        Arc::new(StringArray::from(vec!["foo", "foo"])),
                        Arc::new(StringArray::from(vec!["🥺", "😝"])),
                        Arc::new(Float64Array::from(vec![3.0, 4.0])),
                    ],
                )
                .unwrap(),
            ],
            "+---------------------+-----+\
            \n| ts                  | val |\
            \n+---------------------+-----+\
            \n| 1970-01-01T00:00:00 | NaN |\
            \n| 1970-01-01T00:00:05 | NaN |\
            \n| 1970-01-01T00:00:10 | NaN |\
            \n| 1970-01-01T00:00:15 | NaN |\
            \n+---------------------+-----+",
        )
        .await
    }

    #[tokio::test]
    async fn empty_series() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Millisecond, None), true),
            Field::new("tag1", DataType::Utf8, true),
            Field::new("tag2", DataType::Utf8, true),
            Field::new("val", DataType::Float64, true),
        ]));
        run_test(
            vec![RecordBatch::try_new(
                schema,
                vec![
                    Arc::new(TimestampMillisecondArray::new_null(0)),
                    Arc::new(StringArray::new_null(0)),
                    Arc::new(StringArray::new_null(0)),
                    Arc::new(Float64Array::new_null(0)),
                ],
            )
            .unwrap()],
            "+---------------------+-----+\
            \n| ts                  | val |\
            \n+---------------------+-----+\
            \n| 1970-01-01T00:00:00 | NaN |\
            \n| 1970-01-01T00:00:05 | NaN |\
            \n| 1970-01-01T00:00:10 | NaN |\
            \n| 1970-01-01T00:00:15 | NaN |\
            \n+---------------------+-----+",
        )
        .await
    }
}