1pub mod accumulator;
16mod expr;
17mod udaf;
18
19use std::sync::Arc;
20
21use api::v1::TableName;
22use datafusion::catalog::CatalogProviderList;
23use datafusion::error::Result as DatafusionResult;
24use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};
25use datafusion_common::{Column, TableReference};
26use datafusion_expr::dml::InsertOp;
27use datafusion_expr::{col, DmlStatement, WriteOp};
28pub use expr::{build_filter_from_timestamp, build_same_type_ts_filter};
29use snafu::ResultExt;
30
31pub use self::accumulator::{Accumulator, AggregateFunctionCreator, AggregateFunctionCreatorRef};
32pub use self::udaf::AggregateFunction;
33use crate::error::{GeneralDataFusionSnafu, Result};
34use crate::logical_plan::accumulator::*;
35use crate::signature::{Signature, Volatility};
36
37pub fn create_aggregate_function(
38 name: String,
39 args_count: u8,
40 creator: Arc<dyn AggregateFunctionCreator>,
41) -> AggregateFunction {
42 let return_type = make_return_function(creator.clone());
43 let accumulator = make_accumulator_function(creator.clone());
44 let state_type = make_state_function(creator.clone());
45 AggregateFunction::new(
46 name,
47 Signature::any(args_count as usize, Volatility::Immutable),
48 return_type,
49 accumulator,
50 state_type,
51 creator,
52 )
53}
54
55pub fn rename_logical_plan_columns(
59 enable_ident_normalization: bool,
60 plan: LogicalPlan,
61 renames: Vec<(&str, &str)>,
62) -> DatafusionResult<LogicalPlan> {
63 let mut projection = Vec::with_capacity(renames.len());
64
65 for (old_name, new_name) in renames {
66 let old_column: Column = if enable_ident_normalization {
67 Column::from_qualified_name(old_name)
68 } else {
69 Column::from_qualified_name_ignore_case(old_name)
70 };
71
72 let (qualifier_rename, field_rename) =
73 plan.schema().qualified_field_from_column(&old_column)?;
74
75 for (qualifier, field) in plan.schema().iter() {
76 if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename {
77 projection.push(col(Column::from((qualifier, field))).alias(new_name));
78 }
79 }
80 }
81
82 LogicalPlanBuilder::from(plan).project(projection)?.build()
83}
84
85pub fn breakup_insert_plan(
94 plan: &LogicalPlan,
95 default_catalog: &str,
96 default_schema: &str,
97) -> Option<(TableName, Arc<LogicalPlan>)> {
98 if let LogicalPlan::Dml(dml) = plan {
99 if dml.op != WriteOp::Insert(InsertOp::Append) {
100 return None;
101 }
102 let table_name = &dml.table_name;
103 let table_name = match table_name {
104 TableReference::Bare { table } => TableName {
105 catalog_name: default_catalog.to_string(),
106 schema_name: default_schema.to_string(),
107 table_name: table.to_string(),
108 },
109 TableReference::Partial { schema, table } => TableName {
110 catalog_name: default_catalog.to_string(),
111 schema_name: schema.to_string(),
112 table_name: table.to_string(),
113 },
114 TableReference::Full {
115 catalog,
116 schema,
117 table,
118 } => TableName {
119 catalog_name: catalog.to_string(),
120 schema_name: schema.to_string(),
121 table_name: table.to_string(),
122 },
123 };
124 let logical_plan = dml.input.clone();
125 Some((table_name, logical_plan))
126 } else {
127 None
128 }
129}
130
131pub fn add_insert_to_logical_plan(
133 table_name: TableName,
134 table_schema: datafusion_common::DFSchemaRef,
135 input: LogicalPlan,
136) -> Result<LogicalPlan> {
137 let table_name = TableReference::Full {
138 catalog: table_name.catalog_name.into(),
139 schema: table_name.schema_name.into(),
140 table: table_name.table_name.into(),
141 };
142
143 let plan = LogicalPlan::Dml(DmlStatement::new(
144 table_name,
145 table_schema,
146 WriteOp::Insert(InsertOp::Append),
147 Arc::new(input),
148 ));
149 let plan = plan.recompute_schema().context(GeneralDataFusionSnafu)?;
150 Ok(plan)
151}
152
/// Decoder that turns a serialized plan message into a DataFusion [`LogicalPlan`].
///
/// NOTE(review): the trait name suggests the message is Substrait-encoded; the wire
/// format is not visible from this declaration — confirm against the implementors.
#[async_trait::async_trait]
pub trait SubstraitPlanDecoder {
    /// Decodes `message` into a [`LogicalPlan`].
    ///
    /// * `message` - the raw bytes of the encoded plan.
    /// * `catalog_list` - the catalog providers handed to the decoder (presumably used
    ///   to resolve the tables the plan refers to — confirm with implementors).
    /// * `optimize` - presumably asks the decoder to also optimize the decoded plan;
    ///   the exact effect is up to the implementation — TODO confirm.
    ///
    /// Failures are reported through the crate-level [`Result`].
    async fn decode(
        &self,
        message: bytes::Bytes,
        catalog_list: Arc<dyn CatalogProviderList>,
        optimize: bool,
    ) -> Result<LogicalPlan>;
}
169
/// Reference-counted, `Send + Sync` handle to a [`SubstraitPlanDecoder`] trait object.
pub type SubstraitPlanDecoderRef = Arc<dyn SubstraitPlanDecoder + Send + Sync>;
171
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_expr::builder::LogicalTableSource;
    use datafusion_expr::lit;
    use datatypes::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datatypes::prelude::*;
    use datatypes::vectors::VectorRef;

    use super::*;
    use crate::error::Result;
    use crate::function::AccumulatorCreatorFunction;
    use crate::signature::TypeSignature;

    /// Accumulator stub: keeps no state, ignores every batch and always evaluates to
    /// `Int32(0)`.
    #[derive(Debug)]
    struct DummyAccumulator;

    impl Accumulator for DummyAccumulator {
        fn state(&self) -> Result<Vec<Value>> {
            Ok(Vec::new())
        }

        fn update_batch(&mut self, _values: &[VectorRef]) -> Result<()> {
            Ok(())
        }

        fn merge_batch(&mut self, _states: &[VectorRef]) -> Result<()> {
            Ok(())
        }

        fn evaluate(&self) -> Result<Value> {
            Ok(Value::Int32(0))
        }
    }

    /// Creator stub with a fixed `float64` input type that hands out
    /// [`DummyAccumulator`]s.
    #[derive(Debug)]
    struct DummyAccumulatorCreator;

    impl AggrFuncTypeStore for DummyAccumulatorCreator {
        fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
            let float64 = ConcreteDataType::float64_datatype();
            Ok(vec![float64])
        }

        fn set_input_types(&self, _: Vec<ConcreteDataType>) -> Result<()> {
            Ok(())
        }
    }

    impl AggregateFunctionCreator for DummyAccumulatorCreator {
        fn creator(&self) -> AccumulatorCreatorFunction {
            Arc::new(|_| Ok(Box::new(DummyAccumulator)))
        }

        fn output_type(&self) -> Result<ConcreteDataType> {
            // The output type mirrors the first (and only) input type.
            let mut input_types = self.input_types()?.into_iter();
            Ok(input_types.next().unwrap())
        }

        fn state_types(&self) -> Result<Vec<ConcreteDataType>> {
            let float64 = ConcreteDataType::float64_datatype();
            let uint32 = ConcreteDataType::uint32_datatype();
            Ok(vec![float64, uint32])
        }
    }

    /// Builds `Filter(id > 500)` on top of a scan of a two-column `person` table.
    fn mock_plan() -> LogicalPlan {
        let fields = vec![
            Field::new("id", DataType::Int32, true),
            Field::new("name", DataType::Utf8, true),
        ];
        let source = LogicalTableSource::new(SchemaRef::new(Schema::new(fields)));

        let scan = LogicalPlanBuilder::scan("person", Arc::new(source), None).unwrap();
        let filtered = scan.filter(col("id").gt(lit(500))).unwrap();
        filtered.build().unwrap()
    }

    #[test]
    fn test_rename_logical_plan_columns() {
        let renames = vec![("id", "a"), ("name", "b")];
        let renamed = rename_logical_plan_columns(true, mock_plan(), renames).unwrap();

        // One entry per plan node; each nesting level is indented by two spaces in the
        // plan's indented display.
        let expected = [
            "Projection: person.id AS a, person.name AS b",
            "  Filter: person.id > Int32(500)",
            "    TableScan: person",
        ]
        .join("\n");

        assert_eq!(expected, renamed.to_string());
    }

    #[test]
    fn test_create_udaf() {
        let udaf = create_aggregate_function(
            "dummy".to_string(),
            1,
            Arc::new(DummyAccumulatorCreator),
        );
        assert_eq!("dummy", udaf.name);

        assert_eq!(TypeSignature::Any(1), udaf.signature.type_signature);
        assert_eq!(Volatility::Immutable, udaf.signature.volatility);

        let float64 = ConcreteDataType::float64_datatype();

        let return_type = (udaf.return_type)(&[float64.clone()]).unwrap();
        assert_eq!(Arc::new(float64.clone()), return_type);

        let state_types = (udaf.state_type)(&float64).unwrap();
        assert_eq!(
            Arc::new(vec![float64.clone(), ConcreteDataType::uint32_datatype()]),
            state_types
        );
    }
}