common_function/scalars/date/
date_sub.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16
17use common_query::error::ArrowComputeSnafu;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_expr::{ScalarFunctionArgs, Signature};
20use datatypes::arrow::compute::kernels::numeric;
21use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
22use snafu::ResultExt;
23
24use crate::function::{Function, extract_args};
25use crate::helper;
26
27/// A function subtracts an interval value to Timestamp, Date, and return the result.
28/// The implementation of datetime type is based on Date64 which is incorrect so this function
29/// doesn't support the datetime type.
30#[derive(Clone, Debug)]
31pub(crate) struct DateSubFunction {
32    signature: Signature,
33}
34
35impl Default for DateSubFunction {
36    fn default() -> Self {
37        Self {
38            signature: helper::one_of_sigs2(
39                vec![
40                    DataType::Date32,
41                    DataType::Timestamp(TimeUnit::Second, None),
42                    DataType::Timestamp(TimeUnit::Millisecond, None),
43                    DataType::Timestamp(TimeUnit::Microsecond, None),
44                    DataType::Timestamp(TimeUnit::Nanosecond, None),
45                ],
46                vec![
47                    DataType::Interval(IntervalUnit::MonthDayNano),
48                    DataType::Interval(IntervalUnit::YearMonth),
49                    DataType::Interval(IntervalUnit::DayTime),
50                ],
51            ),
52        }
53    }
54}
55
56impl Function for DateSubFunction {
57    fn name(&self) -> &str {
58        "date_sub"
59    }
60
61    fn return_type(&self, input_types: &[DataType]) -> datafusion_common::Result<DataType> {
62        Ok(input_types[0].clone())
63    }
64
65    fn signature(&self) -> &Signature {
66        &self.signature
67    }
68
69    fn invoke_with_args(
70        &self,
71        args: ScalarFunctionArgs,
72    ) -> datafusion_common::Result<ColumnarValue> {
73        let [left, right] = extract_args(self.name(), &args)?;
74
75        let result = numeric::sub(&left, &right).context(ArrowComputeSnafu)?;
76        Ok(ColumnarValue::Array(result))
77    }
78}
79
80impl fmt::Display for DateSubFunction {
81    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
82        write!(f, "DATE_SUB")
83    }
84}
85
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{
        Array, ArrayRef, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
        TimestampSecondArray,
    };
    use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{TypeSignature, Volatility};

    use super::*;

    /// Runs `date_sub` over `args` and returns the single resulting array.
    ///
    /// `return_type` is the type declared for the result field; each test passes the type it
    /// expects the call to produce.
    fn invoke(args: Vec<ColumnarValue>, number_rows: usize, return_type: DataType) -> ArrayRef {
        DateSubFunction::default()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![],
                number_rows,
                return_field: Arc::new(Field::new("x", return_type, true)),
                config_options: Arc::new(ConfigOptions::new()),
            })
            .and_then(|v| ColumnarValue::values_to_arrays(&[v]))
            .map(|mut arrays| arrays.remove(0))
            .unwrap()
    }

    /// Checks the name, the return type (always the first argument's type) and the shape of
    /// the signature.
    #[test]
    fn test_date_sub_misc() {
        let f = DateSubFunction::default();
        assert_eq!("date_sub", f.name());
        assert_eq!(
            DataType::Timestamp(TimeUnit::Microsecond, None),
            f.return_type(&[DataType::Timestamp(TimeUnit::Microsecond, None)])
                .unwrap()
        );
        assert_eq!(
            DataType::Timestamp(TimeUnit::Second, None),
            f.return_type(&[DataType::Timestamp(TimeUnit::Second, None)])
                .unwrap()
        );
        assert_eq!(
            DataType::Date32,
            f.return_type(&[DataType::Date32]).unwrap()
        );
        // 5 date/timestamp types x 3 interval types = 15 accepted signatures.
        assert!(
            matches!(
                f.signature(),
                Signature {
                    type_signature: TypeSignature::OneOf(sigs),
                    volatility: Volatility::Immutable,
                    ..
                } if sigs.len() == 15
            ),
            "{:?}",
            f.signature()
        );
    }

    /// Subtracting day-time intervals from second-precision timestamps; nulls stay null.
    #[test]
    fn test_timestamp_date_sub() {
        let times = vec![Some(123), None, Some(42), None];
        // Intervals in milliseconds
        let intervals = vec![
            IntervalDayTime::new(0, 1000),
            IntervalDayTime::new(0, 2000),
            IntervalDayTime::new(0, 3000),
            IntervalDayTime::new(0, 1000),
        ];
        let expected = [Some(122), None, Some(39), None];

        let args = vec![
            ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times))),
            ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))),
        ];

        let array = invoke(
            args,
            expected.len(),
            DataType::Timestamp(TimeUnit::Second, None),
        );
        let array = array.as_primitive::<TimestampSecondType>();

        assert_eq!(expected.len(), array.len());
        for (i, want) in expected.iter().enumerate() {
            match want {
                Some(x) => {
                    assert!(array.is_valid(i));
                    assert_eq!(array.value(i), *x);
                }
                None => assert!(array.is_null(i)),
            }
        }
    }

    /// Subtracting year-month intervals from dates; nulls stay null.
    #[test]
    fn test_date_date_sub() {
        let days_per_month = 30;

        let dates = vec![
            Some(123 * days_per_month),
            None,
            Some(42 * days_per_month),
            None,
        ];
        // Intervals in months
        let intervals = vec![1, 2, 3, 1];
        let expected = [Some(3659), None, Some(1168), None];

        let args = vec![
            ColumnarValue::Array(Arc::new(Date32Array::from(dates))),
            ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))),
        ];

        // The result of `Date32 - interval` is `Date32`, so that is the declared return type.
        let array = invoke(args, expected.len(), DataType::Date32);
        let array = array.as_primitive::<Date32Type>();

        assert_eq!(expected.len(), array.len());
        for (i, want) in expected.iter().enumerate() {
            match want {
                Some(x) => {
                    assert!(array.is_valid(i));
                    assert_eq!(array.value(i), *x);
                }
                None => assert!(array.is_null(i)),
            }
        }
    }
}