common_query/logical_plan/accumulator.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Accumulator module contains the trait definition for aggregate functions' accumulators.

use std::fmt::Debug;
use std::sync::Arc;

use datafusion_common::Result as DfResult;
use datafusion_expr::Accumulator as DfAccumulator;
use datatypes::arrow::array::ArrayRef;
use datatypes::prelude::*;
use datatypes::vectors::{Helper as VectorHelper, VectorRef};
use snafu::ResultExt;

use crate::error::{self, FromScalarValueSnafu, IntoVectorSnafu, Result};
use crate::prelude::*;

pub type AggregateFunctionCreatorRef = Arc<dyn AggregateFunctionCreator>;

/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
/// generically accumulates values.
///
/// An accumulator knows how to:
/// * update its state from inputs via `update_batch`
/// * convert its internal state to a vector of scalar values
/// * update its state from multiple accumulators' states via `merge_batch`
/// * compute the final value from its internal state via `evaluate`
///
/// Modified from DataFusion.
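///
/// # Example
///
/// A minimal sketch of a counting accumulator, shown for illustration only.
/// `CountAccumulator` is hypothetical, and the `Value` conversions and the vector
/// accessors (`len`, `null_count`, `get`) are assumptions about the `datatypes` crate:
///
/// ```rust,ignore
/// #[derive(Debug, Default)]
/// struct CountAccumulator {
///     count: u64,
/// }
///
/// impl Accumulator for CountAccumulator {
///     fn state(&self) -> Result<Vec<Value>> {
///         // The partial state is just the running count.
///         Ok(vec![Value::from(self.count)])
///     }
///
///     fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> {
///         // Count the non-null rows of the first input column.
///         if let Some(column) = values.first() {
///             self.count += (column.len() - column.null_count()) as u64;
///         }
///         Ok(())
///     }
///
///     fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> {
///         // Each row of the state vector is a partial count produced by another
///         // accumulator's `state()`.
///         if let Some(partial_counts) = states.first() {
///             for i in 0..partial_counts.len() {
///                 if let Value::UInt64(partial) = partial_counts.get(i) {
///                     self.count += partial;
///                 }
///             }
///         }
///         Ok(())
///     }
///
///     fn evaluate(&self) -> Result<Value> {
///         Ok(Value::from(self.count))
///     }
/// }
/// ```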
pub trait Accumulator: Send + Sync + Debug {
    /// Returns the state of the accumulator at the end of the accumulation.
    // In the case of an average, for which we track `sum` and `n`, this function should return a
    // vector of two values: sum and n.
    fn state(&self) -> Result<Vec<Value>>;

    /// Updates the accumulator's state from a slice of input vectors.
    fn update_batch(&mut self, values: &[VectorRef]) -> Result<()>;

    /// Updates the accumulator's state from a slice of state vectors (the outputs of
    /// other accumulators' `state()`).
    fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()>;

    /// Returns the final value based on the accumulator's current state.
    fn evaluate(&self) -> Result<Value>;
}

/// An `AggregateFunctionCreator` dynamically creates `Accumulator`s.
///
/// An `AggregateFunctionCreator` often has a companion struct that stores the input data types
/// (impl [AggrFuncTypeStore]) and knows the output and state types of its Accumulator.
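///
/// # Example
///
/// A rough sketch of a `creator` implementation, shown for illustration only.
/// `CountAccumulator` is hypothetical, and the closure signature is assumed to match
/// `AccumulatorCreatorFunction`:
///
/// ```rust,ignore
/// fn creator(&self) -> AccumulatorCreatorFunction {
///     Arc::new(|_input_types: &[ConcreteDataType]| {
///         // This accumulator ignores its input types; a real implementation would
///         // typically dispatch on them to pick a concrete accumulator.
///         Ok(Box::new(CountAccumulator::default()) as Box<dyn Accumulator>)
///     })
/// }
/// ```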
pub trait AggregateFunctionCreator: AggrFuncTypeStore {
    /// Create a function that can create a new accumulator for the given input data types.
    fn creator(&self) -> AccumulatorCreatorFunction;

    /// Get the Accumulator's output data type.
    fn output_type(&self) -> Result<ConcreteDataType>;

    /// Get the Accumulator's state data types.
    fn state_types(&self) -> Result<Vec<ConcreteDataType>>;
}

/// `AggrFuncTypeStore` stores the aggregate function's input data types.
///
/// When creating an Accumulator generically, we have to know the input data types.
/// However, DataFusion does not provide them at the time the Accumulator is created.
/// To solve the problem, we store the data types upfront here.
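///
/// # Example
///
/// A minimal sketch of a store backed by a `Mutex`, shown for illustration only.
/// `MyTypeStore` is hypothetical; a real implementation may prefer a lock-free cell and
/// should report a proper error when the types have not been set yet:
///
/// ```rust,ignore
/// #[derive(Debug, Default)]
/// struct MyTypeStore {
///     input_types: std::sync::Mutex<Option<Vec<ConcreteDataType>>>,
/// }
///
/// impl AggrFuncTypeStore for MyTypeStore {
///     fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
///         // Falls back to an empty Vec here for brevity.
///         Ok(self.input_types.lock().unwrap().clone().unwrap_or_default())
///     }
///
///     fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
///         // `set_input_types` takes `&self`, so interior mutability is required.
///         *self.input_types.lock().unwrap() = Some(input_types);
///         Ok(())
///     }
/// }
/// ```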
pub trait AggrFuncTypeStore: Send + Sync + Debug {
    /// Get the input data types of the Accumulator.
    fn input_types(&self) -> Result<Vec<ConcreteDataType>>;

    /// Store the input data types that are provided by DataFusion at runtime (when it evaluates
    /// the return type function).
    fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()>;
}

/// Builds an `AccumulatorFunctionImpl` that creates accumulators through the given creator,
/// using the input data types the creator has stored.
pub fn make_accumulator_function(
    creator: Arc<dyn AggregateFunctionCreator>,
) -> AccumulatorFunctionImpl {
    Arc::new(move || {
        let input_types = creator.input_types()?;
        let creator = creator.creator();
        creator(&input_types)
    })
}

/// Builds a `ReturnTypeFunction` that stores the input data types in the creator and then
/// reports the accumulator's output data type.
pub fn make_return_function(creator: Arc<dyn AggregateFunctionCreator>) -> ReturnTypeFunction {
    Arc::new(move |input_types| {
        creator.set_input_types(input_types.to_vec())?;

        let output_type = creator.output_type()?;
        Ok(Arc::new(output_type))
    })
}

/// Builds a `StateTypeFunction` that reports the accumulator's state data types.
pub fn make_state_function(creator: Arc<dyn AggregateFunctionCreator>) -> StateTypeFunction {
    Arc::new(move |_| Ok(Arc::new(creator.state_types()?)))
}
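
// The three helpers above are typically wired together when a user-defined aggregate
// function is registered. A rough sketch (illustrative only; `MyCreator` is a
// hypothetical implementor of `AggregateFunctionCreator`, and the actual registration
// goes through the query engine's UDAF plumbing):
//
//     let creator: AggregateFunctionCreatorRef = Arc::new(MyCreator::default());
//     let accumulator_fn = make_accumulator_function(creator.clone());
//     let return_type_fn = make_return_function(creator.clone());
//     let state_type_fn = make_state_function(creator);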

/// A wrapper type that adapts our Accumulator to DataFusion's Accumulator,
/// so that our Accumulator can be executed by the DataFusion query engine.
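///
/// # Example
///
/// Constructing the adaptor from a boxed accumulator and its creator (an illustrative
/// sketch; `CountAccumulator` and `creator` are hypothetical values):
///
/// ```rust,ignore
/// let accumulator = Box::new(CountAccumulator::default());
/// let adaptor = DfAccumulatorAdaptor::new(accumulator, creator);
/// ```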
#[derive(Debug)]
pub struct DfAccumulatorAdaptor {
    accumulator: Box<dyn Accumulator>,
    creator: AggregateFunctionCreatorRef,
}

impl DfAccumulatorAdaptor {
    pub fn new(accumulator: Box<dyn Accumulator>, creator: AggregateFunctionCreatorRef) -> Self {
        Self {
            accumulator,
            creator,
        }
    }
}

impl DfAccumulator for DfAccumulatorAdaptor {
    fn state(&mut self) -> DfResult<Vec<ScalarValue>> {
        let state_values = self.accumulator.state()?;
        let state_types = self.creator.state_types()?;
        if state_values.len() != state_types.len() {
            return error::BadAccumulatorImplSnafu {
                err_msg: format!("Accumulator {self:?} returned a number of state values that does not match its number of state types."),
            }
            .fail()?;
        }
        Ok(state_values
            .into_iter()
            .zip(state_types.iter())
            .map(|(v, t)| v.try_to_scalar_value(t).context(error::ToScalarValueSnafu))
            .collect::<Result<Vec<_>>>()?)
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
        let vectors = VectorHelper::try_into_vectors(values).context(FromScalarValueSnafu)?;
        self.accumulator.update_batch(&vectors)?;
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
        let mut vectors = Vec::with_capacity(states.len());
        for array in states.iter() {
            vectors.push(
                VectorHelper::try_into_vector(array).context(IntoVectorSnafu {
                    data_type: array.data_type().clone(),
                })?,
            );
        }
        self.accumulator.merge_batch(&vectors)?;
        Ok(())
    }

    fn evaluate(&mut self) -> DfResult<ScalarValue> {
        let value = self.accumulator.evaluate()?;
        let output_type = self.creator.output_type()?;
        let scalar_value = value
            .try_to_scalar_value(&output_type)
            .context(error::ToScalarValueSnafu)?;
        Ok(scalar_value)
    }

    fn size(&self) -> usize {
        // The wrapped accumulator does not track its memory usage, so report zero here.
        0
    }
}