common_function/scalars/date/
date_add.rs1use std::fmt;
16
17use common_query::error::ArrowComputeSnafu;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_expr::{ScalarFunctionArgs, Signature};
20use datatypes::arrow::compute::kernels::numeric;
21use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
22use snafu::ResultExt;
23
24use crate::function::{Function, extract_args};
25use crate::helper;
26
27#[derive(Clone, Debug)]
31pub(crate) struct DateAddFunction {
32 signature: Signature,
33}
34
35impl Default for DateAddFunction {
36 fn default() -> Self {
37 Self {
38 signature: helper::one_of_sigs2(
39 vec![
40 DataType::Date32,
41 DataType::Timestamp(TimeUnit::Second, None),
42 DataType::Timestamp(TimeUnit::Millisecond, None),
43 DataType::Timestamp(TimeUnit::Microsecond, None),
44 DataType::Timestamp(TimeUnit::Nanosecond, None),
45 ],
46 vec![
47 DataType::Interval(IntervalUnit::MonthDayNano),
48 DataType::Interval(IntervalUnit::YearMonth),
49 DataType::Interval(IntervalUnit::DayTime),
50 ],
51 ),
52 }
53 }
54}
55
/// Name under which this function is exposed (returned by `Function::name`).
const NAME: &str = "date_add";
57
58impl Function for DateAddFunction {
59 fn name(&self) -> &str {
60 NAME
61 }
62
63 fn return_type(&self, input_types: &[DataType]) -> datafusion_common::Result<DataType> {
64 Ok(input_types[0].clone())
65 }
66
67 fn signature(&self) -> &Signature {
68 &self.signature
69 }
70
71 fn invoke_with_args(
72 &self,
73 args: ScalarFunctionArgs,
74 ) -> datafusion_common::Result<ColumnarValue> {
75 let [left, right] = extract_args(self.name(), &args)?;
76
77 let result = numeric::add(&left, &right).context(ArrowComputeSnafu)?;
78 Ok(ColumnarValue::Array(result))
79 }
80}
81
82impl fmt::Display for DateAddFunction {
83 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
84 write!(f, "DATE_ADD")
85 }
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{
        ArrayRef, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
        TimestampSecondArray,
    };
    use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{TypeSignature, Volatility};

    use super::*;

    /// Runs `f` over `args` (`number_rows` rows) and returns the result as a single array.
    ///
    /// `return_type` should be the type `f` returns for these arguments, i.e. the
    /// type of the first (date/timestamp) argument.
    fn invoke(
        f: &DateAddFunction,
        args: Vec<ColumnarValue>,
        number_rows: usize,
        return_type: DataType,
    ) -> ArrayRef {
        f.invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", return_type, true)),
            config_options: Arc::new(ConfigOptions::new()),
        })
        .and_then(|v| ColumnarValue::values_to_arrays(&[v]))
        .map(|mut arrays| arrays.remove(0))
        .unwrap()
    }

    #[test]
    fn test_date_add_misc() {
        let f = DateAddFunction::default();
        assert_eq!("date_add", f.name());
        assert_eq!(
            DataType::Timestamp(TimeUnit::Microsecond, None),
            f.return_type(&[DataType::Timestamp(TimeUnit::Microsecond, None)])
                .unwrap()
        );
        assert_eq!(
            DataType::Timestamp(TimeUnit::Second, None),
            f.return_type(&[DataType::Timestamp(TimeUnit::Second, None)])
                .unwrap()
        );
        assert_eq!(
            DataType::Date32,
            f.return_type(&[DataType::Date32]).unwrap()
        );
        // 5 date-like types x 3 interval types.
        assert!(
            matches!(f.signature(),
                     Signature {
                         type_signature: TypeSignature::OneOf(sigs),
                         volatility: Volatility::Immutable
                     } if sigs.len() == 15),
            "{:?}",
            f.signature()
        );
    }

    #[test]
    fn test_timestamp_date_add() {
        let f = DateAddFunction::default();

        let times = vec![Some(123), None, Some(42), None];
        // Day-time intervals of 1s, 2s, 3s and 1s (given in milliseconds).
        let intervals = vec![
            IntervalDayTime::new(0, 1000),
            IntervalDayTime::new(0, 2000),
            IntervalDayTime::new(0, 3000),
            IntervalDayTime::new(0, 1000),
        ];
        // Null timestamps stay null regardless of the interval.
        let expected: [Option<i64>; 4] = [Some(124), None, Some(45), None];

        let args = vec![
            ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times))),
            ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))),
        ];

        let result = invoke(&f, args, 4, DataType::Timestamp(TimeUnit::Second, None));
        let actual: Vec<Option<i64>> = result
            .as_primitive::<TimestampSecondType>()
            .iter()
            .collect();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_date_date_add() {
        let f = DateAddFunction::default();

        let dates = vec![Some(123), None, Some(42), None];
        // Year-month intervals of 1, 2, 3 and 1 months.
        let intervals = vec![1, 2, 3, 1];
        // 1970-05-04 (day 123) + 1 month  = 1970-06-04 (day 154);
        // 1970-02-12 (day 42)  + 3 months = 1970-05-12 (day 131).
        let expected: [Option<i32>; 4] = [Some(154), None, Some(131), None];

        let args = vec![
            ColumnarValue::Array(Arc::new(Date32Array::from(dates))),
            ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))),
        ];

        // Adding an interval to a Date32 produces a Date32, so the declared
        // return field must be Date32 as well.
        let result = invoke(&f, args, 4, DataType::Date32);
        let actual: Vec<Option<i32>> = result.as_primitive::<Date32Type>().iter().collect();

        assert_eq!(actual, expected);
    }
}