common_query/logical_plan.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

pub mod accumulator;
mod expr;
mod udaf;

use std::sync::Arc;

use api::v1::TableName;
use datafusion::catalog::CatalogProviderList;
use datafusion::error::Result as DatafusionResult;
use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};
use datafusion_common::{Column, TableReference};
use datafusion_expr::dml::InsertOp;
use datafusion_expr::{col, DmlStatement, WriteOp};
pub use expr::{build_filter_from_timestamp, build_same_type_ts_filter};
use snafu::ResultExt;

pub use self::accumulator::{Accumulator, AggregateFunctionCreator, AggregateFunctionCreatorRef};
pub use self::udaf::AggregateFunction;
use crate::error::{GeneralDataFusionSnafu, Result};
use crate::logical_plan::accumulator::*;
use crate::signature::{Signature, Volatility};

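/// Create an [`AggregateFunction`] from an [`AggregateFunctionCreator`]. The creator
/// supplies the return type, accumulator and state types; `args_count` becomes an
/// `any(args_count)` signature with [`Volatility::Immutable`].
///
/// # Example
///
/// A minimal sketch, assuming a hypothetical `MyAccumulatorCreator` type that
/// implements [`AggregateFunctionCreator`]:
///
/// ```ignore
/// let creator = Arc::new(MyAccumulatorCreator::default());
/// let udaf = create_aggregate_function("my_aggr".to_string(), 1, creator);
/// assert_eq!("my_aggr", udaf.name);
/// ```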
pub fn create_aggregate_function(
    name: String,
    args_count: u8,
    creator: Arc<dyn AggregateFunctionCreator>,
) -> AggregateFunction {
    let return_type = make_return_function(creator.clone());
    let accumulator = make_accumulator_function(creator.clone());
    let state_type = make_state_function(creator.clone());
    AggregateFunction::new(
        name,
        Signature::any(args_count as usize, Volatility::Immutable),
        return_type,
        accumulator,
        state_type,
        creator,
    )
}

/// Rename columns by applying a new projection. Returns an error if a column to be
/// renamed does not exist. The `renames` parameter is a `Vec` of
/// `(old_name, new_name)` pairs.
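///
/// # Example
///
/// A minimal sketch, assuming `plan` scans a table `person` that has `id` and `name`
/// columns:
///
/// ```ignore
/// let renamed = rename_logical_plan_columns(true, plan, vec![("id", "a"), ("name", "b")])?;
/// ```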
pub fn rename_logical_plan_columns(
    enable_ident_normalization: bool,
    plan: LogicalPlan,
    renames: Vec<(&str, &str)>,
) -> DatafusionResult<LogicalPlan> {
    let mut projection = Vec::with_capacity(renames.len());

    for (old_name, new_name) in renames {
        let old_column: Column = if enable_ident_normalization {
            Column::from_qualified_name(old_name)
        } else {
            Column::from_qualified_name_ignore_case(old_name)
        };

        let (qualifier_rename, field_rename) =
            plan.schema().qualified_field_from_column(&old_column)?;

        for (qualifier, field) in plan.schema().iter() {
            if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename {
                projection.push(col(Column::from((qualifier, field))).alias(new_name));
            }
        }
    }

    LogicalPlanBuilder::from(plan).project(projection)?.build()
}

/// Break up an `INSERT INTO` logical plan into a `(table_name, logical_plan)` pair,
/// where `table_name` names the table to insert into and `logical_plan` is the input
/// plan to be executed.
///
/// Returns `None` if the input is not an `INSERT INTO table_name <input>` plan.
///
/// The returned [`TableName`] uses the provided default catalog and schema when the
/// plan does not specify them; a fully qualified table name in the plan is **NOT**
/// overridden.
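///
/// # Example
///
/// A minimal sketch, assuming `plan` was built as `INSERT INTO my_table <input>` and
/// using placeholder default catalog/schema names:
///
/// ```ignore
/// if let Some((table_name, input)) = breakup_insert_plan(&plan, "greptime", "public") {
///     // `table_name` is qualified with the defaults where the plan omits them;
///     // `input` is the plan that produces the rows to insert.
/// }
/// ```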
pub fn breakup_insert_plan(
    plan: &LogicalPlan,
    default_catalog: &str,
    default_schema: &str,
) -> Option<(TableName, Arc<LogicalPlan>)> {
    if let LogicalPlan::Dml(dml) = plan {
        if dml.op != WriteOp::Insert(InsertOp::Append) {
            return None;
        }
        let table_name = &dml.table_name;
        let table_name = match table_name {
            TableReference::Bare { table } => TableName {
                catalog_name: default_catalog.to_string(),
                schema_name: default_schema.to_string(),
                table_name: table.to_string(),
            },
            TableReference::Partial { schema, table } => TableName {
                catalog_name: default_catalog.to_string(),
                schema_name: schema.to_string(),
                table_name: table.to_string(),
            },
            TableReference::Full {
                catalog,
                schema,
                table,
            } => TableName {
                catalog_name: catalog.to_string(),
                schema_name: schema.to_string(),
                table_name: table.to_string(),
            },
        };
        let logical_plan = dml.input.clone();
        Some((table_name, logical_plan))
    } else {
        None
    }
}

/// Create an `INSERT INTO table_name <input>` logical plan.
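///
/// # Example
///
/// A minimal sketch, assuming `table_name`, `table_schema` and `input` have already
/// been built for the target table:
///
/// ```ignore
/// let insert_plan = add_insert_to_logical_plan(table_name, table_schema, input)?;
/// ```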
pub fn add_insert_to_logical_plan(
    table_name: TableName,
    table_schema: datafusion_common::DFSchemaRef,
    input: LogicalPlan,
) -> Result<LogicalPlan> {
    let table_name = TableReference::Full {
        catalog: table_name.catalog_name.into(),
        schema: table_name.schema_name.into(),
        table: table_name.table_name.into(),
    };

    let plan = LogicalPlan::Dml(DmlStatement::new(
        table_name,
        table_schema,
        WriteOp::Insert(InsertOp::Append),
        Arc::new(input),
    ));
    let plan = plan.recompute_schema().context(GeneralDataFusionSnafu)?;
    Ok(plan)
}

/// The DataFusion [`LogicalPlan`] decoder.
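///
/// # Example
///
/// A minimal sketch of calling a decoder through a [`SubstraitPlanDecoderRef`],
/// assuming `decoder`, `message` and `catalog_list` are already available:
///
/// ```ignore
/// let plan = decoder.decode(message, catalog_list, true).await?;
/// ```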
#[async_trait::async_trait]
pub trait SubstraitPlanDecoder {
    /// Decode the [`LogicalPlan`] from bytes with the [`CatalogProviderList`].
    /// When `optimize` is true, the decoded plan is also optimized.
    ///
    /// TODO(dennis): It's not a good design for an API to do many things.
    /// The `optimize` flag was introduced because of the cyclic dependency between
    /// `query` and `catalog`; I am happy to refactor it if we have a better solution.
    async fn decode(
        &self,
        message: bytes::Bytes,
        catalog_list: Arc<dyn CatalogProviderList>,
        optimize: bool,
    ) -> Result<LogicalPlan>;
}

pub type SubstraitPlanDecoderRef = Arc<dyn SubstraitPlanDecoder + Send + Sync>;

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_expr::builder::LogicalTableSource;
    use datafusion_expr::lit;
    use datatypes::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datatypes::prelude::*;
    use datatypes::vectors::VectorRef;

    use super::*;
    use crate::error::Result;
    use crate::function::AccumulatorCreatorFunction;
    use crate::signature::TypeSignature;

    #[derive(Debug)]
    struct DummyAccumulator;

    impl Accumulator for DummyAccumulator {
        fn state(&self) -> Result<Vec<Value>> {
            Ok(vec![])
        }

        fn update_batch(&mut self, _values: &[VectorRef]) -> Result<()> {
            Ok(())
        }

        fn merge_batch(&mut self, _states: &[VectorRef]) -> Result<()> {
            Ok(())
        }

        fn evaluate(&self) -> Result<Value> {
            Ok(Value::Int32(0))
        }
    }

    #[derive(Debug)]
    struct DummyAccumulatorCreator;

    impl AggrFuncTypeStore for DummyAccumulatorCreator {
        fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
            Ok(vec![ConcreteDataType::float64_datatype()])
        }

        fn set_input_types(&self, _: Vec<ConcreteDataType>) -> Result<()> {
            Ok(())
        }
    }

    impl AggregateFunctionCreator for DummyAccumulatorCreator {
        fn creator(&self) -> AccumulatorCreatorFunction {
            Arc::new(|_| Ok(Box::new(DummyAccumulator)))
        }

        fn output_type(&self) -> Result<ConcreteDataType> {
            Ok(self.input_types()?.into_iter().next().unwrap())
        }

        fn state_types(&self) -> Result<Vec<ConcreteDataType>> {
            Ok(vec![
                ConcreteDataType::float64_datatype(),
                ConcreteDataType::uint32_datatype(),
            ])
        }
    }

    fn mock_plan() -> LogicalPlan {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, true),
            Field::new("name", DataType::Utf8, true),
        ]);
        let table_source = LogicalTableSource::new(SchemaRef::new(schema));

        let projection = None;

        let builder =
            LogicalPlanBuilder::scan("person", Arc::new(table_source), projection).unwrap();

        builder
            .filter(col("id").gt(lit(500)))
            .unwrap()
            .build()
            .unwrap()
    }

    #[test]
    fn test_rename_logical_plan_columns() {
        let plan = mock_plan();
        let new_plan =
            rename_logical_plan_columns(true, plan, vec![("id", "a"), ("name", "b")]).unwrap();

        assert_eq!(
            r#"
Projection: person.id AS a, person.name AS b
  Filter: person.id > Int32(500)
    TableScan: person"#,
            format!("\n{}", new_plan)
        );
    }

    #[test]
    fn test_create_udaf() {
        let creator = DummyAccumulatorCreator;
        let udaf = create_aggregate_function("dummy".to_string(), 1, Arc::new(creator));
        assert_eq!("dummy", udaf.name);

        let signature = udaf.signature;
        assert_eq!(TypeSignature::Any(1), signature.type_signature);
        assert_eq!(Volatility::Immutable, signature.volatility);

        assert_eq!(
            Arc::new(ConcreteDataType::float64_datatype()),
            (udaf.return_type)(&[ConcreteDataType::float64_datatype()]).unwrap()
        );
        assert_eq!(
            Arc::new(vec![
                ConcreteDataType::float64_datatype(),
                ConcreteDataType::uint32_datatype(),
            ]),
            (udaf.state_type)(&ConcreteDataType::float64_datatype()).unwrap()
        );
    }
}