common_function/scalars/date/date_add.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16
17use common_query::error::ArrowComputeSnafu;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_expr::{ScalarFunctionArgs, Signature};
20use datatypes::arrow::compute::kernels::numeric;
21use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
22use snafu::ResultExt;
23
24use crate::function::{Function, extract_args};
25use crate::helper;
26
27/// A function adds an interval value to Timestamp, Date, and return the result.
28/// The implementation of datetime type is based on Date64 which is incorrect so this function
29/// doesn't support the datetime type.
30#[derive(Clone, Debug)]
31pub(crate) struct DateAddFunction {
32    signature: Signature,
33}
34
35impl Default for DateAddFunction {
36    fn default() -> Self {
37        Self {
38            signature: helper::one_of_sigs2(
39                vec![
40                    DataType::Date32,
41                    DataType::Timestamp(TimeUnit::Second, None),
42                    DataType::Timestamp(TimeUnit::Millisecond, None),
43                    DataType::Timestamp(TimeUnit::Microsecond, None),
44                    DataType::Timestamp(TimeUnit::Nanosecond, None),
45                ],
46                vec![
47                    DataType::Interval(IntervalUnit::MonthDayNano),
48                    DataType::Interval(IntervalUnit::YearMonth),
49                    DataType::Interval(IntervalUnit::DayTime),
50                ],
51            ),
52        }
53    }
54}
55
56const NAME: &str = "date_add";
57
58impl Function for DateAddFunction {
59    fn name(&self) -> &str {
60        NAME
61    }
62
63    fn return_type(&self, input_types: &[DataType]) -> datafusion_common::Result<DataType> {
64        Ok(input_types[0].clone())
65    }
66
67    fn signature(&self) -> &Signature {
68        &self.signature
69    }
70
71    fn invoke_with_args(
72        &self,
73        args: ScalarFunctionArgs,
74    ) -> datafusion_common::Result<ColumnarValue> {
75        let [left, right] = extract_args(self.name(), &args)?;
76
77        let result = numeric::add(&left, &right).context(ArrowComputeSnafu)?;
78        Ok(ColumnarValue::Array(result))
79    }
80}
81
82impl fmt::Display for DateAddFunction {
83    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
84        write!(f, "DATE_ADD")
85    }
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{
        Array, ArrayRef, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
        TimestampSecondArray,
    };
    use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{TypeSignature, Volatility};

    use super::*;

    /// Runs `f` over `args` and materializes the result as a single array.
    ///
    /// `return_type` is forwarded as the return field. Callers pass the type of the first
    /// argument so the field agrees with what `DateAddFunction::return_type` reports.
    fn invoke(
        f: &DateAddFunction,
        args: Vec<ColumnarValue>,
        number_rows: usize,
        return_type: DataType,
    ) -> ArrayRef {
        f.invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", return_type, true)),
            config_options: Arc::new(ConfigOptions::new()),
        })
        .and_then(|v| ColumnarValue::values_to_arrays(&[v]))
        .map(|mut arrays| arrays.remove(0))
        .unwrap()
    }

    #[test]
    fn test_date_add_misc() {
        let f = DateAddFunction::default();
        assert_eq!("date_add", f.name());

        // The result type always mirrors the first argument's type.
        for ty in [
            DataType::Timestamp(TimeUnit::Microsecond, None),
            DataType::Timestamp(TimeUnit::Second, None),
            DataType::Date32,
        ] {
            assert_eq!(ty, f.return_type(&[ty.clone()]).unwrap());
        }

        // 5 date/timestamp types x 3 interval types.
        assert!(
            matches!(f.signature(),
                         Signature {
                             type_signature: TypeSignature::OneOf(sigs),
                             volatility: Volatility::Immutable,
                             ..
                         } if sigs.len() == 15),
            "{:?}",
            f.signature()
        );
    }

    #[test]
    fn test_timestamp_date_add() {
        let f = DateAddFunction::default();

        let times = vec![Some(123), None, Some(42), None];
        // Intervals in milliseconds
        let intervals = vec![
            IntervalDayTime::new(0, 1000),
            IntervalDayTime::new(0, 2000),
            IntervalDayTime::new(0, 3000),
            IntervalDayTime::new(0, 1000),
        ];
        let expected = [Some(124), None, Some(45), None];

        let args = vec![
            ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times))),
            ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))),
        ];

        let result = invoke(&f, args, 4, DataType::Timestamp(TimeUnit::Second, None));
        let result = result.as_primitive::<TimestampSecondType>();

        assert_eq!(expected.len(), result.len());
        for (i, want) in expected.iter().enumerate() {
            match want {
                Some(x) => {
                    assert!(result.is_valid(i));
                    assert_eq!(result.value(i), *x);
                }
                None => assert!(result.is_null(i)),
            }
        }
    }

    #[test]
    fn test_date_date_add() {
        let f = DateAddFunction::default();

        let dates = vec![Some(123), None, Some(42), None];
        // Intervals in months
        let intervals = vec![1, 2, 3, 1];
        let expected = [Some(154), None, Some(131), None];

        let args = vec![
            ColumnarValue::Array(Arc::new(Date32Array::from(dates))),
            ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))),
        ];

        // Date32 + interval yields Date32, so the return field must be Date32 as well.
        let result = invoke(&f, args, 4, DataType::Date32);
        let result = result.as_primitive::<Date32Type>();

        assert_eq!(expected.len(), result.len());
        for (i, want) in expected.iter().enumerate() {
            match want {
                Some(x) => {
                    assert!(result.is_valid(i));
                    assert_eq!(result.value(i), *x);
                }
                None => assert!(result.is_null(i)),
            }
        }
    }
}