common_function/scalars/json/
json_path_match.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::{self, Display};
16use std::sync::Arc;
17
18use arrow::compute;
19use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder};
20use datafusion_common::arrow::datatypes::DataType;
21use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature};
22
23use crate::function::{Function, extract_args};
24use crate::helper;
25
26/// Check if the given JSON data match the given JSON path's predicate.
27#[derive(Clone, Debug)]
28pub(crate) struct JsonPathMatchFunction {
29    signature: Signature,
30}
31
32impl Default for JsonPathMatchFunction {
33    fn default() -> Self {
34        Self {
35            // TODO(LFC): Use a more clear type here instead of "Binary" for Json input, once we have a "Json" type.
36            signature: helper::one_of_sigs2(
37                vec![DataType::Binary, DataType::BinaryView],
38                vec![DataType::Utf8, DataType::Utf8View],
39            ),
40        }
41    }
42}
43
44const NAME: &str = "json_path_match";
45
46impl Function for JsonPathMatchFunction {
47    fn name(&self) -> &str {
48        NAME
49    }
50
51    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
52        Ok(DataType::Boolean)
53    }
54
55    fn signature(&self) -> &Signature {
56        &self.signature
57    }
58
59    fn invoke_with_args(
60        &self,
61        args: ScalarFunctionArgs,
62    ) -> datafusion_common::Result<ColumnarValue> {
63        let [arg0, arg1] = extract_args(self.name(), &args)?;
64        let arg0 = compute::cast(&arg0, &DataType::BinaryView)?;
65        let jsons = arg0.as_binary_view();
66        let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
67        let paths = arg1.as_string_view();
68
69        let size = jsons.len();
70        let mut builder = BooleanBuilder::with_capacity(size);
71
72        for i in 0..size {
73            let json = jsons.is_valid(i).then(|| jsons.value(i));
74            let path = paths.is_valid(i).then(|| paths.value(i));
75
76            let result = match (json, path) {
77                (Some(json), Some(path)) => {
78                    if !jsonb::is_null(json) {
79                        let json_path = jsonb::jsonpath::parse_json_path(path.as_bytes());
80                        match json_path {
81                            Ok(json_path) => jsonb::path_match(json, json_path).ok(),
82                            Err(_) => None,
83                        }
84                    } else {
85                        None
86                    }
87                }
88                _ => None,
89            };
90            builder.append_option(result);
91        }
92
93        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
94    }
95}
96
97impl Display for JsonPathMatchFunction {
98    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
99        write!(f, "JSON_PATH_MATCH")
100    }
101}
102
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::{BinaryArray, StringArray};

    use super::*;

    /// Checks the function's metadata and runs it over a small table of
    /// JSON/path pairs, comparing every output row with its expected value.
    #[test]
    fn test_json_path_match_function() {
        let function = JsonPathMatchFunction::default();

        assert_eq!("json_path_match", function.name());
        assert_eq!(
            DataType::Boolean,
            function.return_type(&[DataType::Binary]).unwrap()
        );

        // Each case is (JSON text, JSON path predicate, expected output);
        // an expected `None` stands for a NULL output row.
        let cases: [(&str, &str, Option<bool>); 7] = [
            (r#"{"a": {"b": 2}, "b": 2, "c": 3}"#, "$.a.b == 2", Some(true)),
            (r#"{"a": 1, "b": [1,2,3]}"#, "$.b[1 to last] >= 2", Some(true)),
            // The document has no `c` member.
            (r#"{"a": 1 ,"b": [1,2,3]}"#, "$.c > 0", Some(false)),
            (r#"[1,2,3]"#, "$[0 to last] > 0", Some(true)),
            // A path of `null` yields NULL.
            (r#"{"a":1,"b":[1,2,3]}"#, r#"null"#, None),
            // A JSON `null` document yields NULL whatever the path is.
            (r#"null"#, "$.c > 0", None),
            (r#"null"#, r#"null"#, None),
        ];
        let row_count = cases.len();

        // The function consumes JSONB-encoded bytes, not JSON text.
        let jsonbs = cases
            .iter()
            .map(|(json, _, _)| Some(jsonb::parse_value(json.as_bytes()).unwrap().to_vec()));
        let paths = cases.iter().map(|(_, path, _)| Some(*path));
        let expected: Vec<Option<bool>> = cases.iter().map(|(_, _, want)| *want).collect();

        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(BinaryArray::from_iter(jsonbs))),
                ColumnarValue::Array(Arc::new(StringArray::from_iter(paths))),
            ],
            arg_fields: vec![],
            number_rows: row_count,
            return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
            config_options: Arc::new(Default::default()),
        };

        let output = function
            .invoke_with_args(args)
            .and_then(|value| value.to_array(row_count))
            .unwrap();
        let booleans = output.as_boolean();

        assert_eq!(row_count, booleans.len());
        let actual: Vec<Option<bool>> = booleans.iter().collect();
        assert_eq!(actual, expected);
    }
}