common_function/scalars/date/
date_sub.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16
17use common_query::error::{ArrowComputeSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, Result};
18use datafusion_expr::Signature;
19use datatypes::arrow::compute::kernels::numeric;
20use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
21use datatypes::vectors::{Helper, VectorRef};
22use snafu::{ResultExt, ensure};
23
24use crate::function::{Function, FunctionContext};
25use crate::helper;
26
/// Scalar function that subtracts an interval from a `Date` or `Timestamp` value and
/// returns the shifted value.
///
/// The datetime type is deliberately not supported: its implementation is based on
/// Date64, which is incorrect.
#[derive(Debug, Default, Clone)]
pub struct DateSubFunction;

/// Name reported by `Function::name` for this function.
const NAME: &str = "date_sub";
34
35impl Function for DateSubFunction {
36    fn name(&self) -> &str {
37        NAME
38    }
39
40    fn return_type(&self, input_types: &[DataType]) -> Result<DataType> {
41        Ok(input_types[0].clone())
42    }
43
44    fn signature(&self) -> Signature {
45        helper::one_of_sigs2(
46            vec![
47                DataType::Date32,
48                DataType::Timestamp(TimeUnit::Second, None),
49                DataType::Timestamp(TimeUnit::Millisecond, None),
50                DataType::Timestamp(TimeUnit::Microsecond, None),
51                DataType::Timestamp(TimeUnit::Nanosecond, None),
52            ],
53            vec![
54                DataType::Interval(IntervalUnit::MonthDayNano),
55                DataType::Interval(IntervalUnit::YearMonth),
56                DataType::Interval(IntervalUnit::DayTime),
57            ],
58        )
59    }
60
61    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
62        ensure!(
63            columns.len() == 2,
64            InvalidFuncArgsSnafu {
65                err_msg: format!(
66                    "The length of the args is not correct, expect 2, have: {}",
67                    columns.len()
68                ),
69            }
70        );
71
72        let left = columns[0].to_arrow_array();
73        let right = columns[1].to_arrow_array();
74
75        let result = numeric::sub(&left, &right).context(ArrowComputeSnafu)?;
76        let arrow_type = result.data_type().clone();
77        Helper::try_into_vector(result).context(IntoVectorSnafu {
78            data_type: arrow_type,
79        })
80    }
81}
82
83impl fmt::Display for DateSubFunction {
84    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85        write!(f, "DATE_SUB")
86    }
87}
88
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_expr::{TypeSignature, Volatility};
    use datatypes::arrow::datatypes::IntervalDayTime;
    use datatypes::value::Value;
    use datatypes::vectors::{
        DateVector, IntervalDayTimeVector, IntervalYearMonthVector, TimestampSecondVector,
    };

    use super::{DateSubFunction, *};

    /// Checks the function name, the return type for several inputs, and the number
    /// of accepted signatures.
    #[test]
    fn test_date_sub_misc() {
        let func = DateSubFunction;
        assert_eq!("date_sub", func.name());

        // The return type must be the type of the first argument.
        for data_type in [
            DataType::Timestamp(TimeUnit::Microsecond, None),
            DataType::Timestamp(TimeUnit::Second, None),
            DataType::Date32,
        ] {
            assert_eq!(data_type, func.return_type(&[data_type.clone()]).unwrap());
        }

        // 5 date/timestamp types x 3 interval units = 15 signatures.
        let signature = func.signature();
        let has_expected_shape = match &signature {
            Signature {
                type_signature: TypeSignature::OneOf(sigs),
                volatility: Volatility::Immutable,
            } => sigs.len() == 15,
            _ => false,
        };
        assert!(has_expected_shape, "{:?}", signature);
    }

    /// Subtracts day-time intervals from second-precision timestamps; null inputs
    /// must stay null.
    #[test]
    fn test_timestamp_date_sub() {
        let func = DateSubFunction;

        let times = vec![Some(123), None, Some(42), None];
        // `IntervalDayTime::new(days, milliseconds)`: 0 days plus 1000/2000/3000 ms.
        let intervals = vec![
            IntervalDayTime::new(0, 1000),
            IntervalDayTime::new(0, 2000),
            IntervalDayTime::new(0, 3000),
            IntervalDayTime::new(0, 1000),
        ];
        let expected = [Some(122), None, Some(39), None];

        let args: Vec<VectorRef> = vec![
            Arc::new(TimestampSecondVector::from(times)),
            Arc::new(IntervalDayTimeVector::from_vec(intervals)),
        ];
        let vector = func.eval(&FunctionContext::default(), &args).unwrap();

        assert_eq!(4, vector.len());
        for (row, want) in expected.iter().enumerate() {
            match (want, vector.get(row)) {
                (None, value) => assert_eq!(Value::Null, value),
                (Some(want), Value::Timestamp(ts)) => assert_eq!(ts.value(), *want),
                _ => unreachable!(),
            }
        }
    }

    /// Subtracts year-month intervals from dates; null inputs must stay null.
    #[test]
    fn test_date_date_sub() {
        let func = DateSubFunction;
        let days_per_month = 30;

        let dates = vec![
            Some(123 * days_per_month),
            None,
            Some(42 * days_per_month),
            None,
        ];
        // Interval lengths in months.
        let intervals = vec![1, 2, 3, 1];
        let expected = [Some(3659), None, Some(1168), None];

        let args: Vec<VectorRef> = vec![
            Arc::new(DateVector::from(dates)),
            Arc::new(IntervalYearMonthVector::from_vec(intervals)),
        ];
        let vector = func.eval(&FunctionContext::default(), &args).unwrap();

        assert_eq!(4, vector.len());
        for (row, want) in expected.iter().enumerate() {
            match (want, vector.get(row)) {
                (None, value) => assert_eq!(Value::Null, value),
                (Some(want), Value::Date(date)) => assert_eq!(date.val(), *want),
                _ => unreachable!(),
            }
        }
    }
}