common_function/scalars/date/date_sub.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16
17use common_query::error::ArrowComputeSnafu;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_expr::{ScalarFunctionArgs, Signature};
20use datatypes::arrow::compute::kernels::numeric;
21use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
22use snafu::ResultExt;
23
24use crate::function::{Function, extract_args};
25use crate::helper;
26
27/// A function subtracts an interval value to Timestamp, Date, and return the result.
28/// The implementation of datetime type is based on Date64 which is incorrect so this function
29/// doesn't support the datetime type.
30#[derive(Clone, Debug)]
31pub(crate) struct DateSubFunction {
32    signature: Signature,
33}
34
35impl Default for DateSubFunction {
36    fn default() -> Self {
37        Self {
38            signature: helper::one_of_sigs2(
39                vec![
40                    DataType::Date32,
41                    DataType::Timestamp(TimeUnit::Second, None),
42                    DataType::Timestamp(TimeUnit::Millisecond, None),
43                    DataType::Timestamp(TimeUnit::Microsecond, None),
44                    DataType::Timestamp(TimeUnit::Nanosecond, None),
45                ],
46                vec![
47                    DataType::Interval(IntervalUnit::MonthDayNano),
48                    DataType::Interval(IntervalUnit::YearMonth),
49                    DataType::Interval(IntervalUnit::DayTime),
50                ],
51            ),
52        }
53    }
54}
55
56impl Function for DateSubFunction {
57    fn name(&self) -> &str {
58        "date_sub"
59    }
60
61    fn return_type(&self, input_types: &[DataType]) -> datafusion_common::Result<DataType> {
62        Ok(input_types[0].clone())
63    }
64
65    fn signature(&self) -> &Signature {
66        &self.signature
67    }
68
69    fn invoke_with_args(
70        &self,
71        args: ScalarFunctionArgs,
72    ) -> datafusion_common::Result<ColumnarValue> {
73        let [left, right] = extract_args(self.name(), &args)?;
74
75        let result = numeric::sub(&left, &right).context(ArrowComputeSnafu)?;
76        Ok(ColumnarValue::Array(result))
77    }
78}
79
80impl fmt::Display for DateSubFunction {
81    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
82        write!(f, "DATE_SUB")
83    }
84}
85
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{
        ArrayRef, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
        TimestampSecondArray,
    };
    use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{TypeSignature, Volatility};

    use super::*;

    /// Runs `date_sub` over `args` and materializes the single result column.
    ///
    /// `return_type` should be the type of the first argument, matching
    /// `DateSubFunction::return_type`.
    fn invoke_date_sub(
        args: Vec<ColumnarValue>,
        number_rows: usize,
        return_type: DataType,
    ) -> ArrayRef {
        DateSubFunction::default()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields: vec![],
                number_rows,
                return_field: Arc::new(Field::new("x", return_type, true)),
                config_options: Arc::new(ConfigOptions::new()),
            })
            .and_then(|v| ColumnarValue::values_to_arrays(&[v]))
            .map(|mut a| a.remove(0))
            .unwrap()
    }

    #[test]
    fn test_date_sub_misc() {
        let f = DateSubFunction::default();
        assert_eq!("date_sub", f.name());
        assert_eq!(
            DataType::Timestamp(TimeUnit::Microsecond, None),
            f.return_type(&[DataType::Timestamp(TimeUnit::Microsecond, None)])
                .unwrap()
        );
        assert_eq!(
            DataType::Timestamp(TimeUnit::Second, None),
            f.return_type(&[DataType::Timestamp(TimeUnit::Second, None)])
                .unwrap()
        );
        assert_eq!(
            DataType::Date32,
            f.return_type(&[DataType::Date32]).unwrap()
        );
        // 5 temporal types x 3 interval types.
        assert!(
            matches!(f.signature(),
                         Signature {
                             type_signature: TypeSignature::OneOf(sigs),
                             volatility: Volatility::Immutable
                         } if  sigs.len() == 15),
            "{:?}",
            f.signature()
        );
    }

    #[test]
    fn test_timestamp_date_sub() {
        let times = vec![Some(123), None, Some(42), None];
        // Intervals in milliseconds
        let intervals = vec![
            IntervalDayTime::new(0, 1000),
            IntervalDayTime::new(0, 2000),
            IntervalDayTime::new(0, 3000),
            IntervalDayTime::new(0, 1000),
        ];
        let expected = [Some(122), None, Some(39), None];

        let args = vec![
            ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times))),
            ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))),
        ];

        let result = invoke_date_sub(args, 4, DataType::Timestamp(TimeUnit::Second, None));
        let actual: Vec<_> = result.as_primitive::<TimestampSecondType>().iter().collect();

        // Comparing whole columns checks length, null positions and values at once.
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_date_date_sub() {
        let days_per_month = 30;

        let dates = vec![
            Some(123 * days_per_month),
            None,
            Some(42 * days_per_month),
            None,
        ];
        // Intervals in months
        let intervals = vec![1, 2, 3, 1];
        // 3690 = 1980-02-08, minus 1 month = 1980-01-08 (3659);
        // 1260 = 1973-06-14, minus 3 months = 1973-03-14 (1168).
        let expected = [Some(3659), None, Some(1168), None];

        let args = vec![
            ColumnarValue::Array(Arc::new(Date32Array::from(dates))),
            ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))),
        ];

        let result = invoke_date_sub(args, 4, DataType::Date32);
        let actual: Vec<_> = result.as_primitive::<Date32Type>().iter().collect();

        assert_eq!(actual, expected);
    }
}