common_function/scalars/date/
date_add.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::fmt;
16
17use common_query::error::ArrowComputeSnafu;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_expr::{ScalarFunctionArgs, Signature};
20use datatypes::arrow::compute::kernels::numeric;
21use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
22use snafu::ResultExt;
23
24use crate::function::{Function, extract_args};
25use crate::helper;
26
27/// A function adds an interval value to Timestamp, Date, and return the result.
28/// The implementation of datetime type is based on Date64 which is incorrect so this function
29/// doesn't support the datetime type.
30#[derive(Clone, Debug)]
31pub(crate) struct DateAddFunction {
32    signature: Signature,
33}
34
35impl Default for DateAddFunction {
36    fn default() -> Self {
37        Self {
38            signature: helper::one_of_sigs2(
39                vec![
40                    DataType::Date32,
41                    DataType::Timestamp(TimeUnit::Second, None),
42                    DataType::Timestamp(TimeUnit::Millisecond, None),
43                    DataType::Timestamp(TimeUnit::Microsecond, None),
44                    DataType::Timestamp(TimeUnit::Nanosecond, None),
45                ],
46                vec![
47                    DataType::Interval(IntervalUnit::MonthDayNano),
48                    DataType::Interval(IntervalUnit::YearMonth),
49                    DataType::Interval(IntervalUnit::DayTime),
50                ],
51            ),
52        }
53    }
54}
55
56const NAME: &str = "date_add";
57
58impl Function for DateAddFunction {
59    fn name(&self) -> &str {
60        NAME
61    }
62
63    fn return_type(&self, input_types: &[DataType]) -> datafusion_common::Result<DataType> {
64        Ok(input_types[0].clone())
65    }
66
67    fn signature(&self) -> &Signature {
68        &self.signature
69    }
70
71    fn invoke_with_args(
72        &self,
73        args: ScalarFunctionArgs,
74    ) -> datafusion_common::Result<ColumnarValue> {
75        let [left, right] = extract_args(self.name(), &args)?;
76
77        let result = numeric::add(&left, &right).context(ArrowComputeSnafu)?;
78        Ok(ColumnarValue::Array(result))
79    }
80}
81
82impl fmt::Display for DateAddFunction {
83    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
84        write!(f, "DATE_ADD")
85    }
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{
        Array, ArrayRef, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
        TimestampSecondArray,
    };
    use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{TypeSignature, Volatility};

    use super::*;

    /// Calls `f` with two array arguments and returns the result materialized as an array.
    ///
    /// `return_type` populates `return_field`, so it should be the type the planner would infer
    /// for these arguments.
    fn invoke(
        f: &DateAddFunction,
        left: ArrayRef,
        right: ArrayRef,
        return_type: DataType,
    ) -> ArrayRef {
        let number_rows = left.len();
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(left), ColumnarValue::Array(right)],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", return_type, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        f.invoke_with_args(args)
            .and_then(|v| ColumnarValue::values_to_arrays(&[v]))
            .map(|mut arrays| arrays.remove(0))
            .unwrap()
    }

    #[test]
    fn test_date_add_misc() {
        let f = DateAddFunction::default();
        assert_eq!("date_add", f.name());

        // The result type mirrors the first argument's type.
        for input in [
            DataType::Timestamp(TimeUnit::Microsecond, None),
            DataType::Timestamp(TimeUnit::Second, None),
            DataType::Date32,
        ] {
            assert_eq!(input, f.return_type(&[input.clone()]).unwrap());
        }

        // 5 value types x 3 interval types.
        assert!(
            matches!(f.signature(),
                         Signature {
                             type_signature: TypeSignature::OneOf(sigs),
                             volatility: Volatility::Immutable
                         } if  sigs.len() == 15),
            "{:?}",
            f.signature()
        );
    }

    #[test]
    fn test_timestamp_date_add() {
        let f = DateAddFunction::default();

        let times = TimestampSecondArray::from(vec![Some(123), None, Some(42), None]);
        // Day-time intervals built as `IntervalDayTime::new(days, milliseconds)`.
        let intervals = IntervalDayTimeArray::from(vec![
            IntervalDayTime::new(0, 1000),
            IntervalDayTime::new(0, 2000),
            IntervalDayTime::new(0, 3000),
            IntervalDayTime::new(0, 1000),
        ]);

        let result = invoke(
            &f,
            Arc::new(times),
            Arc::new(intervals),
            DataType::Timestamp(TimeUnit::Second, None),
        );
        let actual: Vec<Option<i64>> = result
            .as_primitive::<TimestampSecondType>()
            .iter()
            .collect();

        // Null timestamps stay null; whole seconds of the interval are added to the value.
        assert_eq!(vec![Some(124), None, Some(45), None], actual);
    }

    #[test]
    fn test_date_date_add() {
        let f = DateAddFunction::default();

        // Date32 values: days since 1970-01-01.
        let dates = Date32Array::from(vec![Some(123), None, Some(42), None]);
        // Year-month intervals, in months.
        let intervals = IntervalYearMonthArray::from(vec![1, 2, 3, 1]);

        let result = invoke(&f, Arc::new(dates), Arc::new(intervals), DataType::Date32);
        let actual: Vec<Option<i32>> = result.as_primitive::<Date32Type>().iter().collect();

        // 1970-05-04 + 1 month = 1970-06-04 (day 154);
        // 1970-02-12 + 3 months = 1970-05-12 (day 131).
        assert_eq!(vec![Some(154), None, Some(131), None], actual);
    }
}
224}