common_query/logical_plan/
accumulator.rs1use std::fmt::Debug;
18use std::sync::Arc;
19
20use datafusion_common::Result as DfResult;
21use datafusion_expr::Accumulator as DfAccumulator;
22use datatypes::arrow::array::ArrayRef;
23use datatypes::prelude::*;
24use datatypes::vectors::{Helper as VectorHelper, VectorRef};
25use snafu::ResultExt;
26
27use crate::error::{self, FromScalarValueSnafu, IntoVectorSnafu, Result};
28use crate::prelude::*;
29
30pub type AggregateFunctionCreatorRef = Arc<dyn AggregateFunctionCreator>;
31
32pub trait Accumulator: Send + Sync + Debug {
43 fn state(&self) -> Result<Vec<Value>>;
47
48 fn update_batch(&mut self, values: &[VectorRef]) -> Result<()>;
50
51 fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()>;
53
54 fn evaluate(&self) -> Result<Value>;
56}
57
58pub trait AggregateFunctionCreator: AggrFuncTypeStore {
64 fn creator(&self) -> AccumulatorCreatorFunction;
66
67 fn output_type(&self) -> Result<ConcreteDataType>;
69
70 fn state_types(&self) -> Result<Vec<ConcreteDataType>>;
72}
73
74pub trait AggrFuncTypeStore: Send + Sync + Debug {
80 fn input_types(&self) -> Result<Vec<ConcreteDataType>>;
82
83 fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()>;
86}
87
88pub fn make_accumulator_function(
89 creator: Arc<dyn AggregateFunctionCreator>,
90) -> AccumulatorFunctionImpl {
91 Arc::new(move || {
92 let input_types = creator.input_types()?;
93 let creator = creator.creator();
94 creator(&input_types)
95 })
96}
97
98pub fn make_return_function(creator: Arc<dyn AggregateFunctionCreator>) -> ReturnTypeFunction {
99 Arc::new(move |input_types| {
100 creator.set_input_types(input_types.to_vec())?;
101
102 let output_type = creator.output_type()?;
103 Ok(Arc::new(output_type))
104 })
105}
106
107pub fn make_state_function(creator: Arc<dyn AggregateFunctionCreator>) -> StateTypeFunction {
108 Arc::new(move |_| Ok(Arc::new(creator.state_types()?)))
109}
110
111#[derive(Debug)]
114pub struct DfAccumulatorAdaptor {
115 accumulator: Box<dyn Accumulator>,
116 creator: AggregateFunctionCreatorRef,
117}
118
119impl DfAccumulatorAdaptor {
120 pub fn new(accumulator: Box<dyn Accumulator>, creator: AggregateFunctionCreatorRef) -> Self {
121 Self {
122 accumulator,
123 creator,
124 }
125 }
126}
127
128impl DfAccumulator for DfAccumulatorAdaptor {
129 fn state(&mut self) -> DfResult<Vec<ScalarValue>> {
130 let state_values = self.accumulator.state()?;
131 let state_types = self.creator.state_types()?;
132 if state_values.len() != state_types.len() {
133 return error::BadAccumulatorImplSnafu {
134 err_msg: format!("Accumulator {self:?} returned state values size do not match its state types size."),
135 }
136 .fail()?;
137 }
138 Ok(state_values
139 .into_iter()
140 .zip(state_types.iter())
141 .map(|(v, t)| v.try_to_scalar_value(t).context(error::ToScalarValueSnafu))
142 .collect::<Result<Vec<_>>>()?)
143 }
144
145 fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
146 let vectors = VectorHelper::try_into_vectors(values).context(FromScalarValueSnafu)?;
147 self.accumulator.update_batch(&vectors)?;
148 Ok(())
149 }
150
151 fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
152 let mut vectors = Vec::with_capacity(states.len());
153 for array in states.iter() {
154 vectors.push(
155 VectorHelper::try_into_vector(array).context(IntoVectorSnafu {
156 data_type: array.data_type().clone(),
157 })?,
158 );
159 }
160 self.accumulator.merge_batch(&vectors)?;
161 Ok(())
162 }
163
164 fn evaluate(&mut self) -> DfResult<ScalarValue> {
165 let value = self.accumulator.evaluate()?;
166 let output_type = self.creator.output_type()?;
167 let scalar_value = value
168 .try_to_scalar_value(&output_type)
169 .context(error::ToScalarValueSnafu)?;
170 Ok(scalar_value)
171 }
172
173 fn size(&self) -> usize {
174 0
175 }
176}