promql/extension_plan/
series_divide.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
20use datafusion::arrow::array::{
21    Array, LargeStringArray, StringArray, StringViewArray, UInt64Array,
22};
23use datafusion::arrow::datatypes::{DataType, SchemaRef};
24use datafusion::arrow::record_batch::RecordBatch;
25use datafusion::common::{DFSchema, DFSchemaRef};
26use datafusion::error::Result as DataFusionResult;
27use datafusion::execution::context::TaskContext;
28use datafusion::logical_expr::{EmptyRelation, Expr, LogicalPlan, UserDefinedLogicalNodeCore};
29use datafusion::physical_expr::{LexRequirement, OrderingRequirements, PhysicalSortRequirement};
30use datafusion::physical_plan::expressions::Column as ColumnExpr;
31use datafusion::physical_plan::metrics::{
32    BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricValue, MetricsSet,
33};
34use datafusion::physical_plan::{
35    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, PlanProperties, RecordBatchStream,
36    SendableRecordBatchStream,
37};
38use datafusion_expr::col;
39use datatypes::arrow::compute;
40use datatypes::compute::SortOptions;
41use futures::{Stream, StreamExt, ready};
42use greptime_proto::substrait_extension as pb;
43use prost::Message;
44use snafu::ResultExt;
45
46use crate::error::{DeserializeSnafu, Result};
47use crate::extension_plan::{METRIC_NUM_SERIES, resolve_column_name, serialize_column_index};
48use crate::metrics::PROMQL_SERIES_COUNT;
49
/// How rows are mapped to a time series within one `RecordBatch`: either by a
/// set of raw string tag columns, or by a single pre-computed `UInt64` id.
enum TagIdentifier<'a> {
    /// A group of raw string tag columns.
    Raw(Vec<RawTagColumn<'a>>),
    /// A single UInt64 identifier (tsid).
    Id(&'a UInt64Array),
}
56
57impl<'a> TagIdentifier<'a> {
58    fn try_new(batch: &'a RecordBatch, tag_indices: &[usize]) -> DataFusionResult<Self> {
59        match tag_indices {
60            [] => Ok(Self::Raw(Vec::new())),
61            [index] => {
62                let array = batch.column(*index);
63                if array.data_type() == &DataType::UInt64 {
64                    let array = array
65                        .as_any()
66                        .downcast_ref::<UInt64Array>()
67                        .ok_or_else(|| {
68                            datafusion::error::DataFusionError::Internal(
69                                "Failed to downcast tag column to UInt64Array".to_string(),
70                            )
71                        })?;
72                    Ok(Self::Id(array))
73                } else {
74                    Ok(Self::Raw(vec![RawTagColumn::try_new(array.as_ref())?]))
75                }
76            }
77            indices => Ok(Self::Raw(
78                indices
79                    .iter()
80                    .map(|index| RawTagColumn::try_new(batch.column(*index).as_ref()))
81                    .collect::<DataFusionResult<Vec<_>>>()?,
82            )),
83        }
84    }
85
86    fn equal_at(&self, left_row: usize, other: &Self, right_row: usize) -> DataFusionResult<bool> {
87        match (self, other) {
88            (Self::Id(left), Self::Id(right)) => {
89                if left.is_null(left_row) || right.is_null(right_row) {
90                    return Ok(left.is_null(left_row) && right.is_null(right_row));
91                }
92                Ok(left.value(left_row) == right.value(right_row))
93            }
94            (Self::Raw(left), Self::Raw(right)) => {
95                if left.len() != right.len() {
96                    return Err(datafusion::error::DataFusionError::Internal(format!(
97                        "Mismatched tag column count: left={}, right={}",
98                        left.len(),
99                        right.len()
100                    )));
101                }
102
103                for (left_column, right_column) in left.iter().zip(right.iter()) {
104                    if !left_column.equal_at(left_row, right_column, right_row)? {
105                        return Ok(false);
106                    }
107                }
108                Ok(true)
109            }
110            _ => Err(datafusion::error::DataFusionError::Internal(format!(
111                "Mismatched tag identifier types: left={:?}, right={:?}",
112                self.data_type(),
113                other.data_type()
114            ))),
115        }
116    }
117
118    fn data_type(&self) -> &'static str {
119        match self {
120            Self::Raw(_) => "Raw",
121            Self::Id(_) => "Id",
122        }
123    }
124}
125
/// A borrowed string tag column in one of Arrow's string encodings.
enum RawTagColumn<'a> {
    Utf8(&'a StringArray),
    LargeUtf8(&'a LargeStringArray),
    Utf8View(&'a StringViewArray),
}
131
132impl<'a> RawTagColumn<'a> {
133    fn try_new(array: &'a dyn Array) -> DataFusionResult<Self> {
134        match array.data_type() {
135            DataType::Utf8 => array
136                .as_any()
137                .downcast_ref::<StringArray>()
138                .map(Self::Utf8)
139                .ok_or_else(|| {
140                    datafusion::error::DataFusionError::Internal(
141                        "Failed to downcast tag column to StringArray".to_string(),
142                    )
143                }),
144            DataType::LargeUtf8 => array
145                .as_any()
146                .downcast_ref::<LargeStringArray>()
147                .map(Self::LargeUtf8)
148                .ok_or_else(|| {
149                    datafusion::error::DataFusionError::Internal(
150                        "Failed to downcast tag column to LargeStringArray".to_string(),
151                    )
152                }),
153            DataType::Utf8View => array
154                .as_any()
155                .downcast_ref::<StringViewArray>()
156                .map(Self::Utf8View)
157                .ok_or_else(|| {
158                    datafusion::error::DataFusionError::Internal(
159                        "Failed to downcast tag column to StringViewArray".to_string(),
160                    )
161                }),
162            other => Err(datafusion::error::DataFusionError::Internal(format!(
163                "Unsupported tag column type: {other:?}"
164            ))),
165        }
166    }
167
168    fn is_null(&self, row: usize) -> bool {
169        match self {
170            Self::Utf8(array) => array.is_null(row),
171            Self::LargeUtf8(array) => array.is_null(row),
172            Self::Utf8View(array) => array.is_null(row),
173        }
174    }
175
176    fn value(&self, row: usize) -> &str {
177        match self {
178            Self::Utf8(array) => array.value(row),
179            Self::LargeUtf8(array) => array.value(row),
180            Self::Utf8View(array) => array.value(row),
181        }
182    }
183
184    fn equal_at(&self, left_row: usize, other: &Self, right_row: usize) -> DataFusionResult<bool> {
185        if self.is_null(left_row) || other.is_null(right_row) {
186            return Ok(self.is_null(left_row) && other.is_null(right_row));
187        }
188
189        Ok(self.value(left_row) == other.value(right_row))
190    }
191}
192
/// Logical plan node that splits its (tag-sorted) input into one output batch
/// per time series, where a series is identified by the values of `tag_columns`.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd)]
pub struct SeriesDivide {
    // Names of the tag columns that identify a time series.
    tag_columns: Vec<String>,
    /// `SeriesDivide` requires `time_index` column's name to generate ordering requirement
    /// for input data. But this plan itself doesn't depend on the ordering of time index
    /// column. This is for follow on plans like `RangeManipulate`. Because requiring ordering
    /// here can avoid unnecessary sort in follow on plans.
    time_index_column: String,
    // Child logical plan.
    input: LogicalPlan,
    // Set only on a freshly deserialized node: column indices that still need
    // to be resolved against the real input schema (see `with_exprs_and_inputs`).
    unfix: Option<UnfixIndices>,
}
204
/// Column indices carried by a deserialized [`SeriesDivide`] until the real
/// input schema is available to resolve them into column names.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd)]
struct UnfixIndices {
    pub tag_column_indices: Vec<u64>,
    pub time_index_column_idx: u64,
}
210
impl UserDefinedLogicalNodeCore for SeriesDivide {
    fn name(&self) -> &str {
        Self::name()
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.input]
    }

    // This node doesn't change the schema of its input.
    fn schema(&self) -> &DFSchemaRef {
        self.input.schema()
    }

    fn expressions(&self) -> Vec<Expr> {
        // A freshly deserialized node only holds unresolved column indices,
        // so it cannot report column expressions until `with_exprs_and_inputs`
        // resolves them into names.
        if self.unfix.is_some() {
            return vec![];
        }

        self.tag_columns
            .iter()
            .map(col)
            .chain(std::iter::once(col(&self.time_index_column)))
            .collect()
    }

    /// Tells the projection-pushdown optimizer which input columns this node
    /// needs: everything the parent requests, plus all tag columns and the
    /// time index column (needed to detect series boundaries and ordering).
    fn necessary_children_exprs(&self, output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
        // Column names are not resolved yet; refuse to prune.
        if self.unfix.is_some() {
            return None;
        }

        let input_schema = self.input.schema();
        if output_columns.is_empty() {
            // No explicit projection from the parent: require every column.
            let indices = (0..input_schema.fields().len()).collect::<Vec<_>>();
            return Some(vec![indices]);
        }

        let mut required = Vec::with_capacity(output_columns.len() + 1 + self.tag_columns.len());
        required.extend_from_slice(output_columns);
        for tag in &self.tag_columns {
            // `?` propagates None if a tag column is missing from the schema.
            required.push(input_schema.index_of_column_by_name(None, tag)?);
        }
        required.push(input_schema.index_of_column_by_name(None, &self.time_index_column)?);

        required.sort_unstable();
        required.dedup();
        Some(vec![required])
    }

    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "PromSeriesDivide: tags={:?}", self.tag_columns)
    }

    /// Rebuilds the node with a new input. For a deserialized node (`unfix`
    /// set), this is also where the serialized column indices are resolved
    /// into column names against the new input's schema; either way the
    /// result has `unfix: None`.
    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> DataFusionResult<Self> {
        if inputs.is_empty() {
            return Err(datafusion::error::DataFusionError::Internal(
                "SeriesDivide must have at least one input".to_string(),
            ));
        }

        let input: LogicalPlan = inputs[0].clone();
        let input_schema = input.schema();

        if let Some(unfix) = &self.unfix {
            // transform indices to names
            let tag_columns = unfix
                .tag_column_indices
                .iter()
                .map(|idx| resolve_column_name(*idx, input_schema, "SeriesDivide", "tag"))
                .collect::<DataFusionResult<Vec<String>>>()?;

            let time_index_column = resolve_column_name(
                unfix.time_index_column_idx,
                input_schema,
                "SeriesDivide",
                "time index",
            )?;

            Ok(Self {
                tag_columns,
                time_index_column,
                input,
                unfix: None,
            })
        } else {
            Ok(Self {
                tag_columns: self.tag_columns.clone(),
                time_index_column: self.time_index_column.clone(),
                input,
                unfix: None,
            })
        }
    }
}
308
impl SeriesDivide {
    /// Creates a new `SeriesDivide` logical node dividing `input` by
    /// `tag_columns`, with `time_index_column` used for ordering requirements.
    pub fn new(tag_columns: Vec<String>, time_index_column: String, input: LogicalPlan) -> Self {
        Self {
            tag_columns,
            time_index_column,
            input,
            unfix: None,
        }
    }

    /// Name of this logical plan node.
    pub const fn name() -> &'static str {
        "SeriesDivide"
    }

    /// Builds the corresponding physical plan ([`SeriesDivideExec`]) over
    /// `exec_input`.
    pub fn to_execution_plan(&self, exec_input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        Arc::new(SeriesDivideExec {
            tag_columns: self.tag_columns.clone(),
            time_index_column: self.time_index_column.clone(),
            input: exec_input,
            metric: ExecutionPlanMetricsSet::new(),
        })
    }

    /// The tag column names this node divides by.
    pub fn tags(&self) -> &[String] {
        &self.tag_columns
    }

    /// Serializes this node to protobuf bytes. Column names are encoded as
    /// indices into the current input schema.
    pub fn serialize(&self) -> Vec<u8> {
        let tag_column_indices = self
            .tag_columns
            .iter()
            .map(|name| serialize_column_index(self.input.schema(), name))
            .collect::<Vec<u64>>();

        let time_index_column_idx =
            serialize_column_index(self.input.schema(), &self.time_index_column);

        pb::SeriesDivide {
            tag_column_indices,
            time_index_column_idx,
            ..Default::default()
        }
        .encode_to_vec()
    }

    /// Deserializes a node from protobuf bytes. The result holds a
    /// placeholder (empty) input plan and unresolved column indices in
    /// `unfix`; both are fixed up later by `with_exprs_and_inputs` once the
    /// real input plan is attached.
    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
        let pb_series_divide = pb::SeriesDivide::decode(bytes).context(DeserializeSnafu)?;
        let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(DFSchema::empty()),
        });

        let unfix = UnfixIndices {
            tag_column_indices: pb_series_divide.tag_column_indices.clone(),
            time_index_column_idx: pb_series_divide.time_index_column_idx,
        };

        Ok(Self {
            tag_columns: Vec::new(),
            time_index_column: String::new(),
            input: placeholder_plan,
            unfix: Some(unfix),
        })
    }
}
374
/// Physical plan for [`SeriesDivide`]: emits one record batch per time series.
#[derive(Debug)]
pub struct SeriesDivideExec {
    // Names of the tag columns identifying a series.
    tag_columns: Vec<String>,
    // Time index column name, used only for the input ordering requirement.
    time_index_column: String,
    // Child physical plan.
    input: Arc<dyn ExecutionPlan>,
    // Shared metrics registry for all partitions of this node.
    metric: ExecutionPlanMetricsSet,
}
382
impl ExecutionPlan for SeriesDivideExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    // Output schema is identical to the input's.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }

    // Partitioning/ordering properties pass straight through from the input.
    fn properties(&self) -> &Arc<PlanProperties> {
        self.input.properties()
    }

    /// Hash-partition the input by the tag columns so a single time series
    /// never spans partitions. With no tag columns everything is one series,
    /// so the input must be a single partition.
    fn required_input_distribution(&self) -> Vec<Distribution> {
        if self.tag_columns.is_empty() {
            return vec![Distribution::SinglePartition];
        }
        let schema = self.input.schema();
        vec![Distribution::HashPartitioned(
            self.tag_columns
                .iter()
                // Safety: the tag column names is verified in the planning phase
                .map(|tag| Arc::new(ColumnExpr::new_with_schema(tag, &schema).unwrap()) as _)
                .collect(),
        )]
    }

    /// Require the input sorted on (tag columns..., time index), ascending
    /// with nulls first. Sorted tags let the stream find series boundaries by
    /// comparing adjacent rows; the time-index part is for downstream plans
    /// (see the comment on `SeriesDivide::time_index_column`).
    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
        let input_schema = self.input.schema();
        let mut exprs: Vec<PhysicalSortRequirement> = self
            .tag_columns
            .iter()
            .map(|tag| PhysicalSortRequirement {
                // Safety: the tag column names is verified in the planning phase
                expr: Arc::new(ColumnExpr::new_with_schema(tag, &input_schema).unwrap()),
                options: Some(SortOptions {
                    descending: false,
                    nulls_first: true,
                }),
            })
            .collect();

        exprs.push(PhysicalSortRequirement {
            expr: Arc::new(
                ColumnExpr::new_with_schema(&self.time_index_column, &input_schema).unwrap(),
            ),
            options: Some(SortOptions {
                descending: false,
                nulls_first: true,
            }),
        });

        // Safety: `exprs` is not empty
        let requirement = LexRequirement::new(exprs).unwrap();

        vec![Some(OrderingRequirements::Hard(vec![requirement]))]
    }

    // Splitting batches at series boundaries preserves row order.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true; self.children().len()]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        assert!(!children.is_empty());
        Ok(Arc::new(Self {
            tag_columns: self.tag_columns.clone(),
            time_index_column: self.time_index_column.clone(),
            input: children[0].clone(),
            metric: self.metric.clone(),
        }))
    }

    /// Executes one partition: resolves tag column indices against the input
    /// schema (panics if a tag column is missing — names are verified during
    /// planning) and wraps the input stream in a [`SeriesDivideStream`].
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        let baseline_metric = BaselineMetrics::new(&self.metric, partition);
        // Register a custom per-partition counter for the number of series.
        let metrics_builder = MetricBuilder::new(&self.metric);
        let num_series = Count::new();
        metrics_builder
            .with_partition(partition)
            .build(MetricValue::Count {
                name: METRIC_NUM_SERIES.into(),
                count: num_series.clone(),
            });

        let input = self.input.execute(partition, context)?;
        let schema = input.schema();
        let tag_indices = self
            .tag_columns
            .iter()
            .map(|tag| {
                schema
                    .column_with_name(tag)
                    .unwrap_or_else(|| panic!("tag column not found {tag}"))
                    .0
            })
            .collect();
        Ok(Box::pin(SeriesDivideStream {
            tag_indices,
            buffer: vec![],
            schema,
            input,
            metric: baseline_metric,
            num_series,
            inspect_start: 0,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metric.clone_inner())
    }

    fn name(&self) -> &str {
        "SeriesDivideExec"
    }
}
508
509impl DisplayAs for SeriesDivideExec {
510    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
511        match t {
512            DisplayFormatType::Default
513            | DisplayFormatType::Verbose
514            | DisplayFormatType::TreeRender => {
515                write!(f, "PromSeriesDivideExec: tags={:?}", self.tag_columns)
516            }
517        }
518    }
519}
520
/// Stream that cuts its input into one batch per time series.
///
/// Assume the input stream is ordered on the tag columns.
pub struct SeriesDivideStream {
    // Indices of the tag columns within `schema`.
    tag_indices: Vec<usize>,
    // Batches accumulated until a series boundary is found.
    buffer: Vec<RecordBatch>,
    // Output schema (identical to the input's).
    schema: SchemaRef,
    // The wrapped input stream.
    input: SendableRecordBatchStream,
    // Baseline metrics (elapsed compute, poll accounting).
    metric: BaselineMetrics,
    /// Index of buffered batches to start inspect next time.
    inspect_start: usize,
    /// Number of series processed.
    num_series: Count,
}
533
534impl RecordBatchStream for SeriesDivideStream {
535    fn schema(&self) -> SchemaRef {
536        self.schema.clone()
537    }
538}
539
540impl Stream for SeriesDivideStream {
541    type Item = DataFusionResult<RecordBatch>;
542
543    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
544        loop {
545            if !self.buffer.is_empty() {
546                let timer = std::time::Instant::now();
547                let cut_at = match self.find_first_diff_row() {
548                    Ok(cut_at) => cut_at,
549                    Err(e) => return Poll::Ready(Some(Err(e))),
550                };
551                if let Some((batch_index, row_index)) = cut_at {
552                    // slice out the first time series and return it.
553                    let half_batch_of_first_series =
554                        self.buffer[batch_index].slice(0, row_index + 1);
555                    let half_batch_of_second_series = self.buffer[batch_index].slice(
556                        row_index + 1,
557                        self.buffer[batch_index].num_rows() - row_index - 1,
558                    );
559                    let result_batches = self
560                        .buffer
561                        .drain(0..batch_index)
562                        .chain([half_batch_of_first_series])
563                        .collect::<Vec<_>>();
564                    if half_batch_of_second_series.num_rows() > 0 {
565                        self.buffer[0] = half_batch_of_second_series;
566                    } else {
567                        self.buffer.remove(0);
568                    }
569                    let result_batch = compute::concat_batches(&self.schema, &result_batches)?;
570
571                    self.inspect_start = 0;
572                    self.num_series.add(1);
573                    self.metric.elapsed_compute().add_elapsed(timer);
574                    return Poll::Ready(Some(Ok(result_batch)));
575                } else {
576                    self.metric.elapsed_compute().add_elapsed(timer);
577                    // continue to fetch next batch as the current buffer only contains one time series.
578                    let next_batch = ready!(self.as_mut().fetch_next_batch(cx)).transpose()?;
579                    let timer = std::time::Instant::now();
580                    if let Some(next_batch) = next_batch {
581                        if next_batch.num_rows() != 0 {
582                            self.buffer.push(next_batch);
583                        }
584                        continue;
585                    } else {
586                        // input stream is ended
587                        let result = compute::concat_batches(&self.schema, &self.buffer)?;
588                        self.buffer.clear();
589                        self.inspect_start = 0;
590                        self.num_series.add(1);
591                        self.metric.elapsed_compute().add_elapsed(timer);
592                        return Poll::Ready(Some(Ok(result)));
593                    }
594                }
595            } else {
596                let batch = match ready!(self.as_mut().fetch_next_batch(cx)) {
597                    Some(Ok(batch)) => batch,
598                    None => {
599                        PROMQL_SERIES_COUNT.observe(self.num_series.value() as f64);
600                        return Poll::Ready(None);
601                    }
602                    error => return Poll::Ready(error),
603                };
604                self.buffer.push(batch);
605                continue;
606            }
607        }
608    }
609}
610
611impl SeriesDivideStream {
612    fn fetch_next_batch(
613        mut self: Pin<&mut Self>,
614        cx: &mut Context<'_>,
615    ) -> Poll<Option<DataFusionResult<RecordBatch>>> {
616        let poll = self.input.poll_next_unpin(cx);
617        self.metric.record_poll(poll)
618    }
619
620    /// Return the position to cut buffer.
621    /// None implies the current buffer only contains one time series.
622    fn find_first_diff_row(&mut self) -> DataFusionResult<Option<(usize, usize)>> {
623        // fast path: no tag columns means all data belongs to the same series.
624        if self.tag_indices.is_empty() {
625            return Ok(None);
626        }
627
628        let mut resumed_batch_index = self.inspect_start;
629
630        for batch in &self.buffer[resumed_batch_index..] {
631            let num_rows = batch.num_rows();
632            let tags = TagIdentifier::try_new(batch, &self.tag_indices)?;
633
634            // check if the first row is the same with last batch's last row
635            if resumed_batch_index > self.inspect_start.checked_sub(1).unwrap_or_default() {
636                let last_batch = &self.buffer[resumed_batch_index - 1];
637                let last_row = last_batch.num_rows() - 1;
638                let last_tags = TagIdentifier::try_new(last_batch, &self.tag_indices)?;
639                if !tags.equal_at(0, &last_tags, last_row)? {
640                    return Ok(Some((resumed_batch_index - 1, last_row)));
641                }
642            }
643
644            // quick check if all rows are the same by comparing the first and last row in this batch
645            if tags.equal_at(0, &tags, num_rows - 1)? {
646                resumed_batch_index += 1;
647                continue;
648            }
649
650            let mut same_until = 0;
651            while same_until < num_rows - 1 {
652                if !tags.equal_at(same_until, &tags, same_until + 1)? {
653                    break;
654                }
655                same_until += 1;
656            }
657
658            if same_until + 1 >= num_rows {
659                // all rows are the same, inspect next batch
660                resumed_batch_index += 1;
661            } else {
662                return Ok(Some((resumed_batch_index, same_until)));
663            }
664        }
665
666        self.inspect_start = resumed_batch_index;
667        Ok(None)
668    }
669}
670
671#[cfg(test)]
672mod test {
673    use datafusion::arrow::datatypes::{DataType, Field, Schema};
674    use datafusion::common::ToDFSchema;
675    use datafusion::datasource::memory::MemorySourceConfig;
676    use datafusion::datasource::source::DataSourceExec;
677    use datafusion::logical_expr::{EmptyRelation, LogicalPlan};
678    use datafusion::prelude::SessionContext;
679
680    use super::*;
681
    /// Builds a three-batch in-memory source with schema (host, path, time_index).
    ///
    /// NOTE(review): the arrays named `path_column_*` are supplied as the FIRST
    /// column and therefore populate the `host` field, while `host_column_*`
    /// populate `path`. The expected outputs in the tests of this module match
    /// this layout (e.g. "foo"/"bar" appear under `host`), so any renaming must
    /// update those expectations too — confirm before changing.
    fn prepare_test_data() -> DataSourceExec {
        let schema = Arc::new(Schema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("path", DataType::Utf8, true),
            Field::new(
                "time_index",
                DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Millisecond, None),
                false,
            ),
        ]));

        // Batch 1: several series, with the last ("bla"/"005") continuing
        // into batch 2.
        let path_column_1 = Arc::new(StringArray::from(vec![
            "foo", "foo", "foo", "bar", "bar", "bar", "bar", "bar", "bar", "bla", "bla", "bla",
        ])) as _;
        let host_column_1 = Arc::new(StringArray::from(vec![
            "000", "000", "001", "002", "002", "002", "002", "002", "003", "005", "005", "005",
        ])) as _;
        let time_index_column_1 = Arc::new(
            datafusion::arrow::array::TimestampMillisecondArray::from(vec![
                1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000,
            ]),
        ) as _;

        // Batch 2: entirely one series ("bla"/"005"), spanning a batch boundary.
        let path_column_2 = Arc::new(StringArray::from(vec!["bla", "bla", "bla"])) as _;
        let host_column_2 = Arc::new(StringArray::from(vec!["005", "005", "005"])) as _;
        let time_index_column_2 = Arc::new(
            datafusion::arrow::array::TimestampMillisecondArray::from(vec![13000, 14000, 15000]),
        ) as _;

        // Batch 3: finishes "bla"/"005" then two multi-byte-UTF-8 tag series.
        let path_column_3 = Arc::new(StringArray::from(vec![
            "bla", "🥺", "🥺", "🥺", "🥺", "🥺", "🫠", "🫠",
        ])) as _;
        let host_column_3 = Arc::new(StringArray::from(vec![
            "005", "001", "001", "001", "001", "001", "001", "001",
        ])) as _;
        let time_index_column_3 =
            Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
                vec![16000, 17000, 18000, 19000, 20000, 21000, 22000, 23000],
            )) as _;

        let data_1 = RecordBatch::try_new(
            schema.clone(),
            vec![path_column_1, host_column_1, time_index_column_1],
        )
        .unwrap();
        let data_2 = RecordBatch::try_new(
            schema.clone(),
            vec![path_column_2, host_column_2, time_index_column_2],
        )
        .unwrap();
        let data_3 = RecordBatch::try_new(
            schema.clone(),
            vec![path_column_3, host_column_3, time_index_column_3],
        )
        .unwrap();

        DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![data_1, data_2, data_3]], schema, None).unwrap(),
        ))
    }
742
743    #[test]
744    fn pruning_should_keep_tags_and_time_index_columns_for_exec() {
745        let df_schema = prepare_test_data().schema().to_dfschema_ref().unwrap();
746        let input = LogicalPlan::EmptyRelation(EmptyRelation {
747            produce_one_row: false,
748            schema: df_schema,
749        });
750        let plan = SeriesDivide::new(
751            vec!["host".to_string(), "path".to_string()],
752            "time_index".to_string(),
753            input,
754        );
755
756        // Simulate a parent projection requesting only the `host` column.
757        let output_columns = [0usize];
758        let required = plan.necessary_children_exprs(&output_columns).unwrap();
759        let required = &required[0];
760        assert_eq!(required.as_slice(), &[0, 1, 2]);
761    }
762
    /// Collects the full divided output and checks that all rows come back in
    /// input order (division preserves ordering; series remain contiguous).
    #[tokio::test]
    async fn overall_data() {
        let memory_exec = Arc::new(prepare_test_data());
        let divide_exec = Arc::new(SeriesDivideExec {
            tag_columns: vec!["host".to_string(), "path".to_string()],
            time_index_column: "time_index".to_string(),
            input: memory_exec,
            metric: ExecutionPlanMetricsSet::new(),
        });
        let session_context = SessionContext::default();
        let result = datafusion::physical_plan::collect(divide_exec, session_context.task_ctx())
            .await
            .unwrap();
        let result_literal = datatypes::arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();

        // Note: the `host` column shows the `path_column_*` values — see the
        // note on `prepare_test_data`.
        let expected = String::from(
            "+------+------+---------------------+\
            \n| host | path | time_index          |\
            \n+------+------+---------------------+\
            \n| foo  | 000  | 1970-01-01T00:00:01 |\
            \n| foo  | 000  | 1970-01-01T00:00:02 |\
            \n| foo  | 001  | 1970-01-01T00:00:03 |\
            \n| bar  | 002  | 1970-01-01T00:00:04 |\
            \n| bar  | 002  | 1970-01-01T00:00:05 |\
            \n| bar  | 002  | 1970-01-01T00:00:06 |\
            \n| bar  | 002  | 1970-01-01T00:00:07 |\
            \n| bar  | 002  | 1970-01-01T00:00:08 |\
            \n| bar  | 003  | 1970-01-01T00:00:09 |\
            \n| bla  | 005  | 1970-01-01T00:00:10 |\
            \n| bla  | 005  | 1970-01-01T00:00:11 |\
            \n| bla  | 005  | 1970-01-01T00:00:12 |\
            \n| bla  | 005  | 1970-01-01T00:00:13 |\
            \n| bla  | 005  | 1970-01-01T00:00:14 |\
            \n| bla  | 005  | 1970-01-01T00:00:15 |\
            \n| bla  | 005  | 1970-01-01T00:00:16 |\
            \n| 🥺   | 001  | 1970-01-01T00:00:17 |\
            \n| 🥺   | 001  | 1970-01-01T00:00:18 |\
            \n| 🥺   | 001  | 1970-01-01T00:00:19 |\
            \n| 🥺   | 001  | 1970-01-01T00:00:20 |\
            \n| 🥺   | 001  | 1970-01-01T00:00:21 |\
            \n| 🫠   | 001  | 1970-01-01T00:00:22 |\
            \n| 🫠   | 001  | 1970-01-01T00:00:23 |\
            \n+------+------+---------------------+",
        );
        assert_eq!(result_literal, expected);
    }
811
812    #[tokio::test]
813    async fn per_batch_data() {
814        let memory_exec = Arc::new(prepare_test_data());
815        let divide_exec = Arc::new(SeriesDivideExec {
816            tag_columns: vec!["host".to_string(), "path".to_string()],
817            time_index_column: "time_index".to_string(),
818            input: memory_exec,
819            metric: ExecutionPlanMetricsSet::new(),
820        });
821        let mut divide_stream = divide_exec
822            .execute(0, SessionContext::default().task_ctx())
823            .unwrap();
824
825        let mut expectations = vec![
826            String::from(
827                "+------+------+---------------------+\
828                \n| host | path | time_index          |\
829                \n+------+------+---------------------+\
830                \n| foo  | 000  | 1970-01-01T00:00:01 |\
831                \n| foo  | 000  | 1970-01-01T00:00:02 |\
832                \n+------+------+---------------------+",
833            ),
834            String::from(
835                "+------+------+---------------------+\
836                \n| host | path | time_index          |\
837                \n+------+------+---------------------+\
838                \n| foo  | 001  | 1970-01-01T00:00:03 |\
839                \n+------+------+---------------------+",
840            ),
841            String::from(
842                "+------+------+---------------------+\
843                \n| host | path | time_index          |\
844                \n+------+------+---------------------+\
845                \n| bar  | 002  | 1970-01-01T00:00:04 |\
846                \n| bar  | 002  | 1970-01-01T00:00:05 |\
847                \n| bar  | 002  | 1970-01-01T00:00:06 |\
848                \n| bar  | 002  | 1970-01-01T00:00:07 |\
849                \n| bar  | 002  | 1970-01-01T00:00:08 |\
850                \n+------+------+---------------------+",
851            ),
852            String::from(
853                "+------+------+---------------------+\
854                \n| host | path | time_index          |\
855                \n+------+------+---------------------+\
856                \n| bar  | 003  | 1970-01-01T00:00:09 |\
857                \n+------+------+---------------------+",
858            ),
859            String::from(
860                "+------+------+---------------------+\
861                \n| host | path | time_index          |\
862                \n+------+------+---------------------+\
863                \n| bla  | 005  | 1970-01-01T00:00:10 |\
864                \n| bla  | 005  | 1970-01-01T00:00:11 |\
865                \n| bla  | 005  | 1970-01-01T00:00:12 |\
866                \n| bla  | 005  | 1970-01-01T00:00:13 |\
867                \n| bla  | 005  | 1970-01-01T00:00:14 |\
868                \n| bla  | 005  | 1970-01-01T00:00:15 |\
869                \n| bla  | 005  | 1970-01-01T00:00:16 |\
870                \n+------+------+---------------------+",
871            ),
872            String::from(
873                "+------+------+---------------------+\
874                \n| host | path | time_index          |\
875                \n+------+------+---------------------+\
876                \n| 🥺   | 001  | 1970-01-01T00:00:17 |\
877                \n| 🥺   | 001  | 1970-01-01T00:00:18 |\
878                \n| 🥺   | 001  | 1970-01-01T00:00:19 |\
879                \n| 🥺   | 001  | 1970-01-01T00:00:20 |\
880                \n| 🥺   | 001  | 1970-01-01T00:00:21 |\
881                \n+------+------+---------------------+",
882            ),
883            String::from(
884                "+------+------+---------------------+\
885                \n| host | path | time_index          |\
886                \n+------+------+---------------------+\
887                \n| 🫠   | 001  | 1970-01-01T00:00:22 |\
888                \n| 🫠   | 001  | 1970-01-01T00:00:23 |\
889                \n+------+------+---------------------+",
890            ),
891        ];
892        expectations.reverse();
893
894        while let Some(batch) = divide_stream.next().await {
895            let formatted =
896                datatypes::arrow::util::pretty::pretty_format_batches(&[batch.unwrap()])
897                    .unwrap()
898                    .to_string();
899            let expected = expectations.pop().unwrap();
900            assert_eq!(formatted, expected);
901        }
902    }
903
904    #[tokio::test]
905    async fn test_all_batches_same_combination() {
906        // Create a schema with host and path columns, same as prepare_test_data
907        let schema = Arc::new(Schema::new(vec![
908            Field::new("host", DataType::Utf8, true),
909            Field::new("path", DataType::Utf8, true),
910            Field::new(
911                "time_index",
912                DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Millisecond, None),
913                false,
914            ),
915        ]));
916
917        // Create batches with three different combinations
918        // Each batch contains only one combination
919        // Batches with the same combination are adjacent
920
921        // First combination: "server1", "/var/log"
922        let batch1 = RecordBatch::try_new(
923            schema.clone(),
924            vec![
925                Arc::new(StringArray::from(vec!["server1", "server1", "server1"])) as _,
926                Arc::new(StringArray::from(vec!["/var/log", "/var/log", "/var/log"])) as _,
927                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
928                    vec![1000, 2000, 3000],
929                )) as _,
930            ],
931        )
932        .unwrap();
933
934        let batch2 = RecordBatch::try_new(
935            schema.clone(),
936            vec![
937                Arc::new(StringArray::from(vec!["server1", "server1"])) as _,
938                Arc::new(StringArray::from(vec!["/var/log", "/var/log"])) as _,
939                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
940                    vec![4000, 5000],
941                )) as _,
942            ],
943        )
944        .unwrap();
945
946        // Second combination: "server2", "/var/data"
947        let batch3 = RecordBatch::try_new(
948            schema.clone(),
949            vec![
950                Arc::new(StringArray::from(vec!["server2", "server2", "server2"])) as _,
951                Arc::new(StringArray::from(vec![
952                    "/var/data",
953                    "/var/data",
954                    "/var/data",
955                ])) as _,
956                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
957                    vec![6000, 7000, 8000],
958                )) as _,
959            ],
960        )
961        .unwrap();
962
963        let batch4 = RecordBatch::try_new(
964            schema.clone(),
965            vec![
966                Arc::new(StringArray::from(vec!["server2"])) as _,
967                Arc::new(StringArray::from(vec!["/var/data"])) as _,
968                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
969                    vec![9000],
970                )) as _,
971            ],
972        )
973        .unwrap();
974
975        // Third combination: "server3", "/opt/logs"
976        let batch5 = RecordBatch::try_new(
977            schema.clone(),
978            vec![
979                Arc::new(StringArray::from(vec!["server3", "server3"])) as _,
980                Arc::new(StringArray::from(vec!["/opt/logs", "/opt/logs"])) as _,
981                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
982                    vec![10000, 11000],
983                )) as _,
984            ],
985        )
986        .unwrap();
987
988        let batch6 = RecordBatch::try_new(
989            schema.clone(),
990            vec![
991                Arc::new(StringArray::from(vec!["server3", "server3", "server3"])) as _,
992                Arc::new(StringArray::from(vec![
993                    "/opt/logs",
994                    "/opt/logs",
995                    "/opt/logs",
996                ])) as _,
997                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
998                    vec![12000, 13000, 14000],
999                )) as _,
1000            ],
1001        )
1002        .unwrap();
1003
1004        // Create MemoryExec with these batches, keeping same combinations adjacent
1005        let memory_exec = DataSourceExec::from_data_source(
1006            MemorySourceConfig::try_new(
1007                &[vec![batch1, batch2, batch3, batch4, batch5, batch6]],
1008                schema.clone(),
1009                None,
1010            )
1011            .unwrap(),
1012        );
1013
1014        // Create SeriesDivideExec
1015        let divide_exec = Arc::new(SeriesDivideExec {
1016            tag_columns: vec!["host".to_string(), "path".to_string()],
1017            time_index_column: "time_index".to_string(),
1018            input: memory_exec,
1019            metric: ExecutionPlanMetricsSet::new(),
1020        });
1021
1022        // Execute the division
1023        let session_context = SessionContext::default();
1024        let result =
1025            datafusion::physical_plan::collect(divide_exec.clone(), session_context.task_ctx())
1026                .await
1027                .unwrap();
1028
1029        // Verify that we got 3 batches (one for each combination)
1030        assert_eq!(result.len(), 3);
1031
1032        // First batch should have 5 rows (3 + 2 from the "server1" combination)
1033        assert_eq!(result[0].num_rows(), 5);
1034
1035        // Second batch should have 4 rows (3 + 1 from the "server2" combination)
1036        assert_eq!(result[1].num_rows(), 4);
1037
1038        // Third batch should have 5 rows (2 + 3 from the "server3" combination)
1039        assert_eq!(result[2].num_rows(), 5);
1040
1041        // Verify values in first batch (server1, /var/log)
1042        let host_array1 = result[0]
1043            .column(0)
1044            .as_any()
1045            .downcast_ref::<StringArray>()
1046            .unwrap();
1047        let path_array1 = result[0]
1048            .column(1)
1049            .as_any()
1050            .downcast_ref::<StringArray>()
1051            .unwrap();
1052        let time_index_array1 = result[0]
1053            .column(2)
1054            .as_any()
1055            .downcast_ref::<datafusion::arrow::array::TimestampMillisecondArray>()
1056            .unwrap();
1057
1058        for i in 0..5 {
1059            assert_eq!(host_array1.value(i), "server1");
1060            assert_eq!(path_array1.value(i), "/var/log");
1061            assert_eq!(time_index_array1.value(i), 1000 + (i as i64) * 1000);
1062        }
1063
1064        // Verify values in second batch (server2, /var/data)
1065        let host_array2 = result[1]
1066            .column(0)
1067            .as_any()
1068            .downcast_ref::<StringArray>()
1069            .unwrap();
1070        let path_array2 = result[1]
1071            .column(1)
1072            .as_any()
1073            .downcast_ref::<StringArray>()
1074            .unwrap();
1075        let time_index_array2 = result[1]
1076            .column(2)
1077            .as_any()
1078            .downcast_ref::<datafusion::arrow::array::TimestampMillisecondArray>()
1079            .unwrap();
1080
1081        for i in 0..4 {
1082            assert_eq!(host_array2.value(i), "server2");
1083            assert_eq!(path_array2.value(i), "/var/data");
1084            assert_eq!(time_index_array2.value(i), 6000 + (i as i64) * 1000);
1085        }
1086
1087        // Verify values in third batch (server3, /opt/logs)
1088        let host_array3 = result[2]
1089            .column(0)
1090            .as_any()
1091            .downcast_ref::<StringArray>()
1092            .unwrap();
1093        let path_array3 = result[2]
1094            .column(1)
1095            .as_any()
1096            .downcast_ref::<StringArray>()
1097            .unwrap();
1098        let time_index_array3 = result[2]
1099            .column(2)
1100            .as_any()
1101            .downcast_ref::<datafusion::arrow::array::TimestampMillisecondArray>()
1102            .unwrap();
1103
1104        for i in 0..5 {
1105            assert_eq!(host_array3.value(i), "server3");
1106            assert_eq!(path_array3.value(i), "/opt/logs");
1107            assert_eq!(time_index_array3.value(i), 10000 + (i as i64) * 1000);
1108        }
1109
1110        // Also verify streaming behavior
1111        let mut divide_stream = divide_exec
1112            .execute(0, SessionContext::default().task_ctx())
1113            .unwrap();
1114
1115        // Should produce three batches, one for each combination
1116        let batch1 = divide_stream.next().await.unwrap().unwrap();
1117        assert_eq!(batch1.num_rows(), 5); // server1 combination
1118
1119        let batch2 = divide_stream.next().await.unwrap().unwrap();
1120        assert_eq!(batch2.num_rows(), 4); // server2 combination
1121
1122        let batch3 = divide_stream.next().await.unwrap().unwrap();
1123        assert_eq!(batch3.num_rows(), 5); // server3 combination
1124
1125        // No more batches should be produced
1126        assert!(divide_stream.next().await.is_none());
1127    }
1128
1129    #[tokio::test]
1130    async fn test_string_tag_column_types() {
1131        let schema = Arc::new(Schema::new(vec![
1132            Field::new("tag_large", DataType::LargeUtf8, false),
1133            Field::new("tag_view", DataType::Utf8View, false),
1134            Field::new(
1135                "time_index",
1136                DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Millisecond, None),
1137                false,
1138            ),
1139        ]));
1140
1141        let batch1 = RecordBatch::try_new(
1142            schema.clone(),
1143            vec![
1144                Arc::new(LargeStringArray::from(vec!["a", "a", "a", "a"])),
1145                Arc::new(StringViewArray::from(vec!["x", "x", "y", "y"])),
1146                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
1147                    vec![1000, 2000, 1000, 2000],
1148                )),
1149            ],
1150        )
1151        .unwrap();
1152
1153        let batch2 = RecordBatch::try_new(
1154            schema.clone(),
1155            vec![
1156                Arc::new(LargeStringArray::from(vec!["b", "b"])),
1157                Arc::new(StringViewArray::from(vec!["x", "x"])),
1158                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
1159                    vec![1000, 2000],
1160                )),
1161            ],
1162        )
1163        .unwrap();
1164
1165        let memory_exec: Arc<dyn ExecutionPlan> = Arc::new(DataSourceExec::new(Arc::new(
1166            MemorySourceConfig::try_new(&[vec![batch1, batch2]], schema.clone(), None).unwrap(),
1167        )));
1168
1169        let divide_exec = Arc::new(SeriesDivideExec {
1170            tag_columns: vec!["tag_large".to_string(), "tag_view".to_string()],
1171            time_index_column: "time_index".to_string(),
1172            input: memory_exec,
1173            metric: ExecutionPlanMetricsSet::new(),
1174        });
1175
1176        let session_context = SessionContext::default();
1177        let result = datafusion::physical_plan::collect(divide_exec, session_context.task_ctx())
1178            .await
1179            .unwrap();
1180
1181        assert_eq!(result.len(), 3);
1182        for ((expected_large, expected_view), batch) in [("a", "x"), ("a", "y"), ("b", "x")]
1183            .into_iter()
1184            .zip(result.iter())
1185        {
1186            assert_eq!(batch.num_rows(), 2);
1187
1188            let tag_large_array = batch
1189                .column(0)
1190                .as_any()
1191                .downcast_ref::<LargeStringArray>()
1192                .unwrap();
1193            let tag_view_array = batch
1194                .column(1)
1195                .as_any()
1196                .downcast_ref::<StringViewArray>()
1197                .unwrap();
1198
1199            for row in 0..batch.num_rows() {
1200                assert_eq!(tag_large_array.value(row), expected_large);
1201                assert_eq!(tag_view_array.value(row), expected_view);
1202            }
1203        }
1204    }
1205
1206    #[tokio::test]
1207    async fn test_u64_tag_column() {
1208        let schema = Arc::new(Schema::new(vec![
1209            Field::new("tsid", DataType::UInt64, false),
1210            Field::new(
1211                "time_index",
1212                DataType::Timestamp(datafusion::arrow::datatypes::TimeUnit::Millisecond, None),
1213                false,
1214            ),
1215        ]));
1216
1217        let batch1 = RecordBatch::try_new(
1218            schema.clone(),
1219            vec![
1220                Arc::new(UInt64Array::from(vec![1, 1, 2, 2])),
1221                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
1222                    vec![1000, 2000, 1000, 2000],
1223                )),
1224            ],
1225        )
1226        .unwrap();
1227
1228        let batch2 = RecordBatch::try_new(
1229            schema.clone(),
1230            vec![
1231                Arc::new(UInt64Array::from(vec![3, 3])),
1232                Arc::new(datafusion::arrow::array::TimestampMillisecondArray::from(
1233                    vec![1000, 2000],
1234                )),
1235            ],
1236        )
1237        .unwrap();
1238
1239        let memory_exec: Arc<dyn ExecutionPlan> = Arc::new(DataSourceExec::new(Arc::new(
1240            MemorySourceConfig::try_new(&[vec![batch1, batch2]], schema.clone(), None).unwrap(),
1241        )));
1242
1243        let divide_exec = Arc::new(SeriesDivideExec {
1244            tag_columns: vec!["tsid".to_string()],
1245            time_index_column: "time_index".to_string(),
1246            input: memory_exec,
1247            metric: ExecutionPlanMetricsSet::new(),
1248        });
1249
1250        let session_context = SessionContext::default();
1251        let result = datafusion::physical_plan::collect(divide_exec, session_context.task_ctx())
1252            .await
1253            .unwrap();
1254
1255        assert_eq!(result.len(), 3);
1256        for (expected_tsid, batch) in [1u64, 2u64, 3u64].into_iter().zip(result.iter()) {
1257            assert_eq!(batch.num_rows(), 2);
1258            let tsid_array = batch
1259                .column(0)
1260                .as_any()
1261                .downcast_ref::<UInt64Array>()
1262                .unwrap();
1263            assert!(tsid_array.iter().all(|v| v == Some(expected_tsid)));
1264        }
1265    }
1266}