Skip to main content

query/optimizer/
json_type_concretize.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use arrow_schema::DataType;
18use common_function::scalars::json::json_get::JsonGetWithType;
19use datafusion::datasource::DefaultTableSource;
20use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
21use datafusion_common::{Result, plan_datafusion_err, plan_err};
22use datafusion_expr::{Expr, LogicalPlan};
23use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
24use datatypes::types::json_type::{JsonNativeType, JsonObjectType};
25
26use crate::dummy_catalog::DummyTableProvider;
27
/// Optimizer rule that infers, from the query itself, which JSON type a scan is expected to
/// produce for each JSON column.
///
/// Every `json_get` call contributes a nested object type; e.g. `select j.a.b::Int64` lets the
/// rule concretize `{ a: { b: Number } }` for `j`. The merged result is later placed into the
/// scan request, where it is used when converting the JSON arrays.
#[derive(Debug)]
pub(crate) struct JsonTypeConcretizeRule;
33
34impl OptimizerRule for JsonTypeConcretizeRule {
35    fn name(&self) -> &str {
36        "JsonTypeConcretizeRule"
37    }
38
39    fn rewrite(
40        &self,
41        plan: LogicalPlan,
42        _config: &dyn OptimizerConfig,
43    ) -> Result<Transformed<LogicalPlan>> {
44        let json_types = deduce_json_types(&plan)?;
45        if json_types.is_empty() {
46            return Ok(Transformed::no(plan));
47        }
48
49        plan.transform_down(|plan| match &plan {
50            LogicalPlan::TableScan(table_scan) => {
51                let Some(source) = table_scan
52                    .source
53                    .as_any()
54                    .downcast_ref::<DefaultTableSource>()
55                else {
56                    return Ok(Transformed::no(plan));
57                };
58
59                let Some(adapter) = source
60                    .table_provider
61                    .as_any()
62                    .downcast_ref::<DummyTableProvider>()
63                else {
64                    return Ok(Transformed::no(plan));
65                };
66
67                adapter.with_json_type_hint(json_types.clone());
68                Ok(Transformed::yes(plan))
69            }
70            _ => Ok(Transformed::no(plan)),
71        })
72    }
73}
74
75fn deduce_json_types(plan: &LogicalPlan) -> Result<HashMap<String, JsonNativeType>> {
76    let mut json_types = HashMap::<String, JsonNativeType>::new();
77
78    plan.apply(|plan| {
79        for expr in plan.expressions() {
80            expr.apply(|expr| {
81                if let Some((column, json_type)) = deduce_json_type(expr)? {
82                    json_types.entry(column).or_default().merge(&json_type);
83                }
84                Ok(TreeNodeRecursion::Continue)
85            })?;
86        }
87        Ok(TreeNodeRecursion::Continue)
88    })?;
89    Ok(json_types)
90}
91
92fn deduce_json_type(expr: &Expr) -> Result<Option<(String, JsonNativeType)>> {
93    let f = match expr {
94        Expr::ScalarFunction(f) if f.name().eq_ignore_ascii_case(JsonGetWithType::NAME) => f,
95        _ => return Ok(None),
96    };
97
98    let Some(Expr::Column(column)) = f.args.first() else {
99        return plan_err!(
100            "First argument of {} is expected to be a column expr, actual: {:?}",
101            JsonGetWithType::NAME,
102            f.args.first()
103        );
104    };
105
106    let Some(path) = f
107        .args
108        .get(1)
109        .and_then(|expr| expr.as_literal())
110        .and_then(|x| x.try_as_str())
111        .flatten()
112    else {
113        return plan_err!(
114            "Second argument of {} is expected to be a string literal, actual: {:?}",
115            JsonGetWithType::NAME,
116            f.args.get(1)
117        );
118    };
119
120    let with_type = f
121        .args
122        .get(2)
123        .and_then(|expr| expr.as_literal())
124        .map(|x| x.data_type())
125        .unwrap_or(DataType::Utf8View);
126    let with_type =
127        JsonNativeType::try_from(&with_type).map_err(|e| plan_datafusion_err!("{e:?}"))?;
128
129    let mut split = path.rsplit(".");
130    let Some(leaf) = split.next() else {
131        return Ok(Some((column.name.clone(), JsonNativeType::String)));
132    };
133
134    let mut object = JsonObjectType::new();
135    object.insert(leaf.to_string(), with_type);
136    let mut root = JsonNativeType::Object(object);
137
138    for s in split {
139        let mut object = JsonObjectType::new();
140        object.insert(s.to_string(), root);
141        root = JsonNativeType::Object(object);
142    }
143
144    Ok(Some((column.name.clone(), root)))
145}
146
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_function::scalars::udf::create_udf;
    use datafusion::datasource::provider_as_source;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{LogicalPlanBuilder, col};
    use datafusion_optimizer::OptimizerContext;
    use store_api::storage::RegionId;

    use super::*;
    use crate::optimizer::test_util::mock_table_provider;

    /// Builds `json_get(base, path)`, appending a NULL literal of `with_type` as the third
    /// argument when a type is requested.
    fn json_get_expr(base: Expr, path: Expr, with_type: Option<DataType>) -> Result<Expr> {
        let udf = Arc::new(create_udf(Arc::new(JsonGetWithType::default())));
        let type_arg = with_type
            .map(|data_type| ScalarValue::try_new_null(&data_type))
            .transpose()?
            .map(|null_value| Expr::Literal(null_value, None));
        let args: Vec<Expr> = [base, path].into_iter().chain(type_arg).collect();
        Ok(Expr::ScalarFunction(ScalarFunction::new_udf(udf, args)))
    }

    /// Wraps a JSON path in a UTF-8 string literal expression.
    fn path_expr(path: &str) -> Expr {
        let literal = ScalarValue::Utf8(Some(path.to_string()));
        Expr::Literal(literal, None)
    }

    /// Projects `exprs` over a scan of the mock table `t`, returning the provider as well so
    /// tests can inspect the scan request it ends up with.
    fn build_plan(exprs: Vec<Expr>) -> Result<(Arc<DummyTableProvider>, LogicalPlan)> {
        let provider = Arc::new(mock_table_provider(RegionId::new(1024, 1)));
        let source = provider_as_source(provider.clone());
        let plan = LogicalPlanBuilder::scan("t", source, None)?
            .project(exprs)?
            .build()?;
        Ok((provider, plan))
    }

    /// Runs the rule once over `plan` and reports whether it claimed a transformation.
    fn run_rule(plan: LogicalPlan) -> Result<bool> {
        let outcome = JsonTypeConcretizeRule.rewrite(plan, &OptimizerContext::default())?;
        Ok(outcome.transformed)
    }

    #[test]
    fn test_json_type_concretize_rule_rewrite() -> Result<()> {
        let (provider, plan) = build_plan(vec![
            json_get_expr(col("k0"), path_expr("a.b"), Some(DataType::Int64))?.alias("ab"),
            json_get_expr(col("k0"), path_expr("a.c"), None)?.alias("ac"),
            json_get_expr(col("k0"), path_expr("d"), Some(DataType::Boolean))?.alias("d"),
        ])?;
        assert!(run_rule(plan)?);

        let nested_a = JsonNativeType::Object(JsonObjectType::from([
            ("b".to_string(), JsonNativeType::i64()),
            ("c".to_string(), JsonNativeType::String),
        ]));
        let expected = JsonNativeType::Object(JsonObjectType::from([
            ("a".to_string(), nested_a),
            ("d".to_string(), JsonNativeType::Bool),
        ]));

        let request = provider.scan_request();
        assert_eq!(request.json_type_hint.len(), 1);
        assert_eq!(request.json_type_hint.get("k0"), Some(&expected));
        Ok(())
    }

    #[test]
    fn test_json_type_concretize_rule_conflict_to_variant() -> Result<()> {
        let (provider, plan) = build_plan(vec![
            json_get_expr(col("k0"), path_expr("a"), Some(DataType::Int64))?.alias("a_num"),
            json_get_expr(col("k0"), path_expr("a.b"), Some(DataType::Boolean))?.alias("a_obj"),
        ])?;
        assert!(run_rule(plan)?);

        // `a` is requested both as a number and as an object, so it ends up as a variant.
        let expected = JsonNativeType::Object(JsonObjectType::from([(
            "a".to_string(),
            JsonNativeType::Variant,
        )]));
        let request = provider.scan_request();
        assert_eq!(request.json_type_hint.get("k0"), Some(&expected));
        Ok(())
    }

    #[test]
    fn test_json_type_concretize_rule_no_json_get() -> Result<()> {
        let (provider, plan) = build_plan(vec![col("k0"), col("v0")])?;

        let transformed = run_rule(plan)?;
        assert!(!transformed);
        assert!(provider.scan_request().json_type_hint.is_empty());
        Ok(())
    }

    #[test]
    fn test_deduce_json_type_with_non_column_base() -> Result<()> {
        let literal_base = Expr::Literal(ScalarValue::Utf8(Some("{}".to_string())), None);
        let expr = json_get_expr(literal_base, path_expr("a"), Some(DataType::Int64))?;

        let message = deduce_json_type(&expr).unwrap_err().to_string();
        let expected = "First argument of json_get is expected to be a column expr";
        assert!(message.contains(expected));
        Ok(())
    }

    #[test]
    fn test_deduce_json_type_with_non_literal_path() -> Result<()> {
        let base = Expr::Column(Column::new_unqualified("k0"));
        let path = Expr::Column(Column::new_unqualified("path_col"));
        let expr = json_get_expr(base, path, Some(DataType::Int64))?;

        let message = deduce_json_type(&expr).unwrap_err().to_string();
        let expected = "Second argument of json_get is expected to be a string literal";
        assert!(message.contains(expected));
        Ok(())
    }

    #[test]
    fn test_deduce_json_type_default_string() -> Result<()> {
        let base = Expr::Column(Column::new_unqualified("k0"));
        let expr = json_get_expr(base, path_expr("a.b"), None)?;

        let deduced = deduce_json_type(&expr)?;

        let leaf = JsonNativeType::Object(JsonObjectType::from([(
            "b".to_string(),
            JsonNativeType::String,
        )]));
        let expected = JsonNativeType::Object(JsonObjectType::from([("a".to_string(), leaf)]));
        assert_eq!(deduced, Some(("k0".to_string(), expected)));
        Ok(())
    }
}