common_function/scalars/json/json_get_rewriter.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#[cfg(test)]
16use std::sync::Arc;
17
18use arrow_schema::{DataType, TimeUnit};
19use datafusion::common::config::ConfigOptions;
20use datafusion::common::tree_node::Transformed;
21use datafusion::common::{DFSchema, Result};
22use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
23use datafusion::scalar::ScalarValue;
24use datafusion_expr::expr::ScalarFunction;
25use datafusion_expr::{Cast, Expr};
26
27use crate::scalars::json::JsonGetWithType;
28
29#[derive(Debug)]
30pub struct JsonGetRewriter;
31
32impl FunctionRewrite for JsonGetRewriter {
33    fn name(&self) -> &'static str {
34        "JsonGetRewriter"
35    }
36
37    fn rewrite(
38        &self,
39        expr: Expr,
40        _schema: &DFSchema,
41        _config: &ConfigOptions,
42    ) -> Result<Transformed<Expr>> {
43        let transform = match &expr {
44            Expr::Cast(cast) => rewrite_json_get_cast(cast),
45            Expr::ScalarFunction(scalar_func) => rewrite_arrow_cast_json_get(scalar_func),
46            _ => None,
47        };
48        Ok(transform.unwrap_or_else(|| Transformed::no(expr)))
49    }
50}
51
52fn is_json_get_function_call(scalar_func: &ScalarFunction) -> bool {
53    scalar_func.func.name().to_ascii_lowercase() == JsonGetWithType::NAME
54        && scalar_func.args.len() == 2
55}
56
57fn rewrite_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
58    let scalar_func = extract_scalar_function(&cast.expr)?;
59    if is_json_get_function_call(scalar_func) {
60        let null_expr = Expr::Literal(ScalarValue::Null, None);
61        let null_cast = Expr::Cast(datafusion::logical_expr::expr::Cast {
62            expr: Box::new(null_expr),
63            data_type: cast.data_type.clone(),
64        });
65
66        let mut args = scalar_func.args.clone();
67        args.push(null_cast);
68
69        Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
70            func: scalar_func.func.clone(),
71            args,
72        })))
73    } else {
74        None
75    }
76}
77
78// Handle Arrow cast function: cast(json_get(a, 'path'), 'Int64')
79fn rewrite_arrow_cast_json_get(scalar_func: &ScalarFunction) -> Option<Transformed<Expr>> {
80    // Check if this is an Arrow cast function
81    // The function name might be "arrow_cast" or similar
82    let func_name = scalar_func.func.name().to_ascii_lowercase();
83    if !func_name.contains("arrow_cast") {
84        return None;
85    }
86
87    // Arrow cast function should have exactly 2 arguments:
88    // 1. The expression to cast (could be json_get)
89    // 2. The target type as a string literal
90    if scalar_func.args.len() != 2 {
91        return None;
92    }
93
94    // Extract the inner json_get function
95    let json_get_func = extract_scalar_function(&scalar_func.args[0])?;
96
97    // Check if it's a json_get function
98    if is_json_get_function_call(json_get_func) {
99        // Get the target type from the second argument
100        let target_type = extract_string_literal(&scalar_func.args[1])?;
101        let data_type = parse_data_type_from_string(&target_type)?;
102
103        // Create the null expression with the same type
104        let null_expr = Expr::Literal(ScalarValue::Null, None);
105        let null_cast = Expr::Cast(datafusion::logical_expr::expr::Cast {
106            expr: Box::new(null_expr),
107            data_type,
108        });
109
110        // Create the new json_get_with_type function with the null parameter
111        let mut args = json_get_func.args.clone();
112        args.push(null_cast);
113
114        Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
115            func: json_get_func.func.clone(),
116            args,
117        })))
118    } else {
119        None
120    }
121}
122
123// Extract string literal from an expression
124fn extract_string_literal(expr: &Expr) -> Option<String> {
125    match expr {
126        Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Some(s.clone()),
127        _ => None,
128    }
129}
130
131// Parse a data type from a string representation
132fn parse_data_type_from_string(type_str: &str) -> Option<DataType> {
133    match type_str.to_lowercase().as_str() {
134        "int8" | "tinyint" => Some(DataType::Int8),
135        "int16" | "smallint" => Some(DataType::Int16),
136        "int32" | "integer" => Some(DataType::Int32),
137        "int64" | "bigint" => Some(DataType::Int64),
138        "uint8" => Some(DataType::UInt8),
139        "uint16" => Some(DataType::UInt16),
140        "uint32" => Some(DataType::UInt32),
141        "uint64" => Some(DataType::UInt64),
142        "float32" | "real" => Some(DataType::Float32),
143        "float64" | "double" => Some(DataType::Float64),
144        "boolean" | "bool" => Some(DataType::Boolean),
145        "string" | "text" | "varchar" => Some(DataType::Utf8),
146        "timestamp" => Some(DataType::Timestamp(TimeUnit::Microsecond, None)),
147        "date" => Some(DataType::Date32),
148        _ => None,
149    }
150}
151
152fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
153    match expr {
154        Expr::ScalarFunction(func) => Some(func),
155        _ => None,
156    }
157}
158
#[cfg(test)]
mod tests {
    use arrow_schema::DataType;
    use datafusion::common::DFSchema;
    use datafusion::common::config::ConfigOptions;
    use datafusion::logical_expr::expr::Cast;
    use datafusion::scalar::ScalarValue;
    use datafusion_expr::Expr;
    use datafusion_expr::expr::ScalarFunction;

    use super::*;

    /// Builds a non-null `Utf8` literal.
    fn utf8_lit(value: &str) -> Expr {
        Expr::Literal(ScalarValue::Utf8(Some(value.to_string())), None)
    }

    /// Builds `CAST(NULL AS data_type)`, the type-hint shape the rewriter appends.
    fn null_cast(data_type: DataType) -> Expr {
        Expr::Cast(Cast {
            expr: Box::new(Expr::Literal(ScalarValue::Null, None)),
            data_type,
        })
    }

    /// Builds a call to the `JsonGetWithType` UDF with the given arguments.
    fn json_get_call(args: Vec<Expr>) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::json::JsonGetWithType::default(),
            ))),
            args,
        })
    }

    /// Builds DataFusion's `arrow_cast(value, 'type_name')`.
    fn arrow_cast_call(value: Expr, type_name: &str) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            func: datafusion::functions::core::arrow_cast(),
            args: vec![value, utf8_lit(type_name)],
        })
    }

    /// Runs [`JsonGetRewriter`] with an empty schema and default config.
    fn apply_rewriter(expr: Expr) -> Transformed<Expr> {
        JsonGetRewriter
            .rewrite(expr, &DFSchema::empty(), &ConfigOptions::new())
            .unwrap()
    }

    /// Asserts the expression was rewritten into a scalar function call and returns that call.
    fn expect_rewritten_call(result: Transformed<Expr>) -> ScalarFunction {
        assert!(result.transformed, "expression should have been rewritten");
        match result.data {
            Expr::ScalarFunction(func) => func,
            other => panic!("Result should be a ScalarFunction, got {other:?}"),
        }
    }

    #[test]
    fn test_rewrite_regular_cast() {
        // json_get('{"a":1}', '$.a')::int8
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get_call(vec![utf8_lit("{\"a\":1}"), utf8_lit("$.a")])),
            data_type: DataType::Int8,
        });

        let func = expect_rewritten_call(apply_rewriter(cast_expr));

        // The original two arguments are kept and a typed NULL is appended.
        assert_eq!(
            func.args,
            vec![
                utf8_lit("{\"a\":1}"),
                utf8_lit("$.a"),
                null_cast(DataType::Int8),
            ]
        );
    }

    #[test]
    fn test_rewrite_cast_with_nested_json_argument() {
        // json_get(parse_json('{"a":1}'), 'a')::bigint
        let parse_json_expr = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::json::ParseJsonFunction::default(),
            ))),
            args: vec![utf8_lit("{\"a\":1}")],
        });
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get_call(vec![parse_json_expr, utf8_lit("a")])),
            data_type: DataType::Int64,
        });

        let func = expect_rewritten_call(apply_rewriter(cast_expr));

        assert_eq!(func.args.len(), 3);
        // The nested parse_json call is carried over unchanged.
        match &func.args[0] {
            Expr::ScalarFunction(parse_func) => {
                assert!(
                    parse_func
                        .func
                        .name()
                        .to_ascii_lowercase()
                        .contains("parse_json")
                );
                assert_eq!(parse_func.args, vec![utf8_lit("{\"a\":1}")]);
            }
            other => panic!("First argument should be a parse_json function, got {other:?}"),
        }
        assert_eq!(func.args[1], utf8_lit("a"));
        assert_eq!(func.args[2], null_cast(DataType::Int64));
    }

    #[test]
    fn test_rewrite_arrow_cast_function() {
        // arrow_cast(json_get('{"a":1}', '$.a'), 'Int64')
        let arrow_cast_expr = arrow_cast_call(
            json_get_call(vec![utf8_lit("{\"a\":1}"), utf8_lit("$.a")]),
            "Int64",
        );

        let func = expect_rewritten_call(apply_rewriter(arrow_cast_expr));

        // The arrow_cast wrapper is replaced by the inner json_get call plus a type hint.
        assert!(
            func.func
                .name()
                .eq_ignore_ascii_case(JsonGetWithType::NAME)
        );
        assert_eq!(
            func.args,
            vec![
                utf8_lit("{\"a\":1}"),
                utf8_lit("$.a"),
                null_cast(DataType::Int64),
            ]
        );
    }

    #[test]
    fn test_no_rewrite_for_arrow_cast_with_unknown_type() {
        // The target type is not one the rewriter knows, so the expression is left alone.
        let arrow_cast_expr = arrow_cast_call(
            json_get_call(vec![utf8_lit("{\"a\":1}"), utf8_lit("$.a")]),
            "Interval(MonthDayNano)",
        );

        assert!(!apply_rewriter(arrow_cast_expr).transformed);
    }

    #[test]
    fn test_no_rewrite_for_already_typed_json_get() {
        // json_get(json, path, CAST(NULL AS Int8))::int8 already carries a type hint.
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(json_get_call(vec![
                utf8_lit("{\"a\":1}"),
                utf8_lit("$.a"),
                null_cast(DataType::Int8),
            ])),
            data_type: DataType::Int8,
        });

        assert!(!apply_rewriter(cast_expr).transformed);
    }

    #[test]
    fn test_no_rewrite_for_other_functions() {
        // A non-json function is never rewritten.
        let other_func = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::test::TestAndFunction::default(),
            ))),
            args: vec![Expr::Literal(ScalarValue::Int64(Some(4)), None)],
        });

        assert!(!apply_rewriter(other_func).transformed);
    }

    #[test]
    fn test_no_rewrite_for_non_cast_functions() {
        // A non-cast function wrapping json_get has the same argument shape as arrow_cast
        // but must not be rewritten.
        let other_func = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(crate::scalars::udf::create_udf(Arc::new(
                crate::scalars::test::TestAndFunction::default(),
            ))),
            args: vec![
                json_get_call(vec![utf8_lit("{\"a\":1}"), utf8_lit("$.a")]),
                utf8_lit("Int64"),
            ],
        });

        assert!(!apply_rewriter(other_func).transformed);
    }

    #[test]
    fn test_parse_data_type_from_string() {
        assert_eq!(parse_data_type_from_string("Int64"), Some(DataType::Int64));
        assert_eq!(parse_data_type_from_string("BIGINT"), Some(DataType::Int64));
        assert_eq!(parse_data_type_from_string("Float64"), Some(DataType::Float64));
        assert_eq!(parse_data_type_from_string("bool"), Some(DataType::Boolean));
        assert_eq!(parse_data_type_from_string("varchar"), Some(DataType::Utf8));
        assert_eq!(
            parse_data_type_from_string("timestamp"),
            Some(DataType::Timestamp(TimeUnit::Microsecond, None))
        );
        assert_eq!(parse_data_type_from_string("not_a_type"), None);
    }

    #[test]
    fn test_extract_string_literal() {
        assert_eq!(
            extract_string_literal(&utf8_lit("Int64")),
            Some("Int64".to_string())
        );
        assert_eq!(
            extract_string_literal(&Expr::Literal(ScalarValue::Utf8(None), None)),
            None
        );
        assert_eq!(
            extract_string_literal(&Expr::Literal(ScalarValue::Int64(Some(1)), None)),
            None
        );
    }
}