common_function/scalars/json/
json_get_rewriter.rs1#[cfg(test)]
16use std::sync::Arc;
17
18use arrow_schema::{DataType, TimeUnit};
19use datafusion::common::config::ConfigOptions;
20use datafusion::common::tree_node::Transformed;
21use datafusion::common::{DFSchema, Result};
22use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
23use datafusion::scalar::ScalarValue;
24use datafusion_expr::expr::ScalarFunction;
25use datafusion_expr::{Cast, Expr};
26
27use crate::scalars::json::JsonGetWithType;
28
29#[derive(Debug)]
30pub struct JsonGetRewriter;
31
32impl FunctionRewrite for JsonGetRewriter {
33 fn name(&self) -> &'static str {
34 "JsonGetRewriter"
35 }
36
37 fn rewrite(
38 &self,
39 expr: Expr,
40 _schema: &DFSchema,
41 _config: &ConfigOptions,
42 ) -> Result<Transformed<Expr>> {
43 let transform = match &expr {
44 Expr::Cast(cast) => rewrite_json_get_cast(cast),
45 Expr::ScalarFunction(scalar_func) => rewrite_arrow_cast_json_get(scalar_func),
46 _ => None,
47 };
48 Ok(transform.unwrap_or_else(|| Transformed::no(expr)))
49 }
50}
51
52fn is_json_get_function_call(scalar_func: &ScalarFunction) -> bool {
53 scalar_func.func.name().to_ascii_lowercase() == JsonGetWithType::NAME
54 && scalar_func.args.len() == 2
55}
56
57fn rewrite_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
58 let scalar_func = extract_scalar_function(&cast.expr)?;
59 if is_json_get_function_call(scalar_func) {
60 let null_expr = Expr::Literal(ScalarValue::Null, None);
61 let null_cast = Expr::Cast(datafusion::logical_expr::expr::Cast {
62 expr: Box::new(null_expr),
63 data_type: cast.data_type.clone(),
64 });
65
66 let mut args = scalar_func.args.clone();
67 args.push(null_cast);
68
69 Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
70 func: scalar_func.func.clone(),
71 args,
72 })))
73 } else {
74 None
75 }
76}
77
78fn rewrite_arrow_cast_json_get(scalar_func: &ScalarFunction) -> Option<Transformed<Expr>> {
80 let func_name = scalar_func.func.name().to_ascii_lowercase();
83 if !func_name.contains("arrow_cast") {
84 return None;
85 }
86
87 if scalar_func.args.len() != 2 {
91 return None;
92 }
93
94 let json_get_func = extract_scalar_function(&scalar_func.args[0])?;
96
97 if is_json_get_function_call(json_get_func) {
99 let target_type = extract_string_literal(&scalar_func.args[1])?;
101 let data_type = parse_data_type_from_string(&target_type)?;
102
103 let null_expr = Expr::Literal(ScalarValue::Null, None);
105 let null_cast = Expr::Cast(datafusion::logical_expr::expr::Cast {
106 expr: Box::new(null_expr),
107 data_type,
108 });
109
110 let mut args = json_get_func.args.clone();
112 args.push(null_cast);
113
114 Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
115 func: json_get_func.func.clone(),
116 args,
117 })))
118 } else {
119 None
120 }
121}
122
123fn extract_string_literal(expr: &Expr) -> Option<String> {
125 match expr {
126 Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Some(s.clone()),
127 _ => None,
128 }
129}
130
131fn parse_data_type_from_string(type_str: &str) -> Option<DataType> {
133 match type_str.to_lowercase().as_str() {
134 "int8" | "tinyint" => Some(DataType::Int8),
135 "int16" | "smallint" => Some(DataType::Int16),
136 "int32" | "integer" => Some(DataType::Int32),
137 "int64" | "bigint" => Some(DataType::Int64),
138 "uint8" => Some(DataType::UInt8),
139 "uint16" => Some(DataType::UInt16),
140 "uint32" => Some(DataType::UInt32),
141 "uint64" => Some(DataType::UInt64),
142 "float32" | "real" => Some(DataType::Float32),
143 "float64" | "double" => Some(DataType::Float64),
144 "boolean" | "bool" => Some(DataType::Boolean),
145 "string" | "text" | "varchar" => Some(DataType::Utf8),
146 "timestamp" => Some(DataType::Timestamp(TimeUnit::Microsecond, None)),
147 "date" => Some(DataType::Date32),
148 _ => None,
149 }
150}
151
152fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
153 match expr {
154 Expr::ScalarFunction(func) => Some(func),
155 _ => None,
156 }
157}
158
#[cfg(test)]
mod tests {
    use arrow_schema::DataType;
    use datafusion::common::DFSchema;
    use datafusion::common::config::ConfigOptions;
    use datafusion::logical_expr::expr::Cast;
    use datafusion::scalar::ScalarValue;
    use datafusion_expr::Expr;
    use datafusion_expr::expr::ScalarFunction;

    use super::*;

    /// Builds a non-null `Utf8` literal.
    fn utf8_lit(value: &str) -> Expr {
        Expr::Literal(ScalarValue::Utf8(Some(value.to_string())), None)
    }

    /// Builds a two-argument `json_get(json, path)` call backed by `JsonGetWithType`.
    fn json_get_call(json: Expr, path: &str) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::json::JsonGetWithType::default(),
            ))),
            args: vec![json, utf8_lit(path)],
        })
    }

    /// Runs `JsonGetRewriter` on `expr` with an empty schema and default config.
    fn rewrite(expr: Expr) -> Transformed<Expr> {
        JsonGetRewriter
            .rewrite(expr, &DFSchema::empty(), &ConfigOptions::new())
            .unwrap()
    }

    /// Unwraps a `ScalarFunction` expression, panicking on any other kind.
    fn expect_scalar_function(expr: Expr) -> ScalarFunction {
        match expr {
            Expr::ScalarFunction(func) => func,
            other => panic!("Result should be a ScalarFunction, got {other:?}"),
        }
    }

    /// Asserts that `actual` is a string literal equal to `expected`.
    fn assert_utf8_lit(actual: &Expr, expected: &str) {
        match actual {
            Expr::Literal(ScalarValue::Utf8(Some(value)), _) => assert_eq!(value, expected),
            other => panic!("Expected string literal {expected:?}, got {other:?}"),
        }
    }

    /// Asserts that `actual` is `CAST(NULL AS expected)`.
    fn assert_null_cast(actual: &Expr, expected: &DataType) {
        match actual {
            Expr::Cast(Cast { expr, data_type }) => {
                assert_eq!(data_type, expected);
                assert!(
                    matches!(expr.as_ref(), Expr::Literal(ScalarValue::Null, _)),
                    "Third argument should be a null cast, got {expr:?}"
                );
            }
            other => panic!("Third argument should be a cast expression, got {other:?}"),
        }
    }

    #[test]
    fn test_rewrite_regular_cast() {
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get_call(utf8_lit("{\"a\":1}"), "$.a")),
            data_type: DataType::Int8,
        });

        let result = rewrite(cast_expr);
        assert!(result.transformed);

        let func = expect_scalar_function(result.data);
        assert_eq!(func.func.name().to_ascii_lowercase(), JsonGetWithType::NAME);
        assert_eq!(func.args.len(), 3);
        assert_utf8_lit(&func.args[0], "{\"a\":1}");
        assert_utf8_lit(&func.args[1], "$.a");
        assert_null_cast(&func.args[2], &DataType::Int8);
    }

    #[test]
    fn test_rewrite_cast_of_nested_json_get() {
        let parse_json_expr = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::json::ParseJsonFunction::default(),
            ))),
            args: vec![utf8_lit("{\"a\":1}")],
        });
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get_call(parse_json_expr, "a")),
            data_type: DataType::Int64,
        });

        let result = rewrite(cast_expr);
        assert!(result.transformed);

        let func = expect_scalar_function(result.data);
        assert_eq!(func.args.len(), 3);
        match &func.args[0] {
            Expr::ScalarFunction(parse_func) => {
                assert!(
                    parse_func
                        .func
                        .name()
                        .to_ascii_lowercase()
                        .contains("parse_json")
                );
                assert_eq!(parse_func.args.len(), 1);
                assert_utf8_lit(&parse_func.args[0], "{\"a\":1}");
            }
            other => panic!("First argument should be a parse_json function, got {other:?}"),
        }
        assert_utf8_lit(&func.args[1], "a");
        assert_null_cast(&func.args[2], &DataType::Int64);
    }

    #[test]
    fn test_rewrite_arrow_cast_function() {
        // A real `arrow_cast(json_get(..), 'Int64')` call, exercising the arrow_cast path.
        let arrow_cast_expr = Expr::ScalarFunction(ScalarFunction {
            func: datafusion::functions::core::arrow_cast(),
            args: vec![
                json_get_call(utf8_lit("{\"a\":1}"), "$.a"),
                utf8_lit("Int64"),
            ],
        });

        let result = rewrite(arrow_cast_expr);
        assert!(result.transformed);

        let func = expect_scalar_function(result.data);
        assert_eq!(func.func.name().to_ascii_lowercase(), JsonGetWithType::NAME);
        assert_eq!(func.args.len(), 3);
        assert_utf8_lit(&func.args[0], "{\"a\":1}");
        assert_utf8_lit(&func.args[1], "$.a");
        assert_null_cast(&func.args[2], &DataType::Int64);
    }

    #[test]
    fn test_no_rewrite_for_unrecognized_arrow_cast_type() {
        let arrow_cast_expr = Expr::ScalarFunction(ScalarFunction {
            func: datafusion::functions::core::arrow_cast(),
            args: vec![
                json_get_call(utf8_lit("{\"a\":1}"), "$.a"),
                utf8_lit("Decimal128(10, 2)"),
            ],
        });

        assert!(!rewrite(arrow_cast_expr).transformed);
    }

    #[test]
    fn test_no_rewrite_for_other_functions() {
        let other_func = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::test::TestAndFunction::default(),
            ))),
            args: vec![Expr::Literal(ScalarValue::Int64(Some(4)), None)],
        });

        assert!(!rewrite(other_func).transformed);
    }

    #[test]
    fn test_no_rewrite_for_non_cast_functions() {
        // A non-cast function whose arguments look like arrow_cast's must be left alone.
        let other_func = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::test::TestAndFunction::default(),
            ))),
            args: vec![
                json_get_call(utf8_lit("{\"a\":1}"), "$.a"),
                utf8_lit("Int64"),
            ],
        });

        assert!(!rewrite(other_func).transformed);
    }
}