1use std::collections::HashMap;
16
17use arrow_schema::DataType;
18use common_function::scalars::json::json_get::JsonGetWithType;
19use datafusion::datasource::DefaultTableSource;
20use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
21use datafusion_common::{Result, plan_datafusion_err, plan_err};
22use datafusion_expr::{Expr, LogicalPlan};
23use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
24use datatypes::types::json_type::{JsonNativeType, JsonObjectType};
25
26use crate::dummy_catalog::DummyTableProvider;
27
28#[derive(Debug)]
32pub(crate) struct JsonTypeConcretizeRule;
33
34impl OptimizerRule for JsonTypeConcretizeRule {
35 fn name(&self) -> &str {
36 "JsonTypeConcretizeRule"
37 }
38
39 fn rewrite(
40 &self,
41 plan: LogicalPlan,
42 _config: &dyn OptimizerConfig,
43 ) -> Result<Transformed<LogicalPlan>> {
44 let json_types = deduce_json_types(&plan)?;
45 if json_types.is_empty() {
46 return Ok(Transformed::no(plan));
47 }
48
49 plan.transform_down(|plan| match &plan {
50 LogicalPlan::TableScan(table_scan) => {
51 let Some(source) = table_scan
52 .source
53 .as_any()
54 .downcast_ref::<DefaultTableSource>()
55 else {
56 return Ok(Transformed::no(plan));
57 };
58
59 let Some(adapter) = source
60 .table_provider
61 .as_any()
62 .downcast_ref::<DummyTableProvider>()
63 else {
64 return Ok(Transformed::no(plan));
65 };
66
67 adapter.with_json_type_hint(json_types.clone());
68 Ok(Transformed::yes(plan))
69 }
70 _ => Ok(Transformed::no(plan)),
71 })
72 }
73}
74
75fn deduce_json_types(plan: &LogicalPlan) -> Result<HashMap<String, JsonNativeType>> {
76 let mut json_types = HashMap::<String, JsonNativeType>::new();
77
78 plan.apply(|plan| {
79 for expr in plan.expressions() {
80 expr.apply(|expr| {
81 if let Some((column, json_type)) = deduce_json_type(expr)? {
82 json_types.entry(column).or_default().merge(&json_type);
83 }
84 Ok(TreeNodeRecursion::Continue)
85 })?;
86 }
87 Ok(TreeNodeRecursion::Continue)
88 })?;
89 Ok(json_types)
90}
91
92fn deduce_json_type(expr: &Expr) -> Result<Option<(String, JsonNativeType)>> {
93 let f = match expr {
94 Expr::ScalarFunction(f) if f.name().eq_ignore_ascii_case(JsonGetWithType::NAME) => f,
95 _ => return Ok(None),
96 };
97
98 let Some(Expr::Column(column)) = f.args.first() else {
99 return plan_err!(
100 "First argument of {} is expected to be a column expr, actual: {:?}",
101 JsonGetWithType::NAME,
102 f.args.first()
103 );
104 };
105
106 let Some(path) = f
107 .args
108 .get(1)
109 .and_then(|expr| expr.as_literal())
110 .and_then(|x| x.try_as_str())
111 .flatten()
112 else {
113 return plan_err!(
114 "Second argument of {} is expected to be a string literal, actual: {:?}",
115 JsonGetWithType::NAME,
116 f.args.get(1)
117 );
118 };
119
120 let with_type = f
121 .args
122 .get(2)
123 .and_then(|expr| expr.as_literal())
124 .map(|x| x.data_type())
125 .unwrap_or(DataType::Utf8View);
126 let with_type =
127 JsonNativeType::try_from(&with_type).map_err(|e| plan_datafusion_err!("{e:?}"))?;
128
129 let mut split = path.rsplit(".");
130 let Some(leaf) = split.next() else {
131 return Ok(Some((column.name.clone(), JsonNativeType::String)));
132 };
133
134 let mut object = JsonObjectType::new();
135 object.insert(leaf.to_string(), with_type);
136 let mut root = JsonNativeType::Object(object);
137
138 for s in split {
139 let mut object = JsonObjectType::new();
140 object.insert(s.to_string(), root);
141 root = JsonNativeType::Object(object);
142 }
143
144 Ok(Some((column.name.clone(), root)))
145}
146
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_function::scalars::udf::create_udf;
    use datafusion::datasource::provider_as_source;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{LogicalPlanBuilder, col};
    use datafusion_optimizer::OptimizerContext;
    use store_api::storage::RegionId;

    use super::*;
    use crate::optimizer::test_util::mock_table_provider;

    /// Builds `json_get(base, path[, <typed NULL literal>])`.
    fn json_get_expr(base: Expr, path: Expr, with_type: Option<DataType>) -> Result<Expr> {
        let json_get = Arc::new(create_udf(Arc::new(JsonGetWithType::default())));
        let mut args = vec![base, path];
        if let Some(with_type) = with_type {
            let with_type = ScalarValue::try_new_null(&with_type)?;
            args.push(Expr::Literal(with_type, None));
        }
        Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
            json_get, args,
        )))
    }

    fn path_expr(path: &str) -> Expr {
        Expr::Literal(ScalarValue::Utf8(Some(path.to_string())), None)
    }

    fn build_plan(exprs: Vec<Expr>) -> Result<(Arc<DummyTableProvider>, LogicalPlan)> {
        let provider = Arc::new(mock_table_provider(RegionId::new(1024, 1)));
        let plan = LogicalPlanBuilder::scan("t", provider_as_source(provider.clone()), None)?
            .project(exprs)?
            .build()?;
        Ok((provider, plan))
    }

    /// Runs the rule on `plan` and reports whether it claimed a transformation.
    fn apply_rule(plan: LogicalPlan) -> Result<bool> {
        Ok(JsonTypeConcretizeRule
            .rewrite(plan, &OptimizerContext::default())?
            .transformed)
    }

    fn column_k0() -> Expr {
        Expr::Column(Column::new_unqualified("k0"))
    }

    #[test]
    fn test_json_type_concretize_rule_rewrite() -> Result<()> {
        let exprs = vec![
            json_get_expr(col("k0"), path_expr("a.b"), Some(DataType::Int64))?.alias("ab"),
            json_get_expr(col("k0"), path_expr("a.c"), None)?.alias("ac"),
            json_get_expr(col("k0"), path_expr("d"), Some(DataType::Boolean))?.alias("d"),
        ];
        let (provider, plan) = build_plan(exprs)?;

        assert!(apply_rule(plan)?);

        let expected = JsonNativeType::Object(JsonObjectType::from([
            (
                "a".to_string(),
                JsonNativeType::Object(JsonObjectType::from([
                    ("b".to_string(), JsonNativeType::i64()),
                    ("c".to_string(), JsonNativeType::String),
                ])),
            ),
            ("d".to_string(), JsonNativeType::Bool),
        ]));

        let request = provider.scan_request();
        assert_eq!(1, request.json_type_hint.len());
        assert_eq!(Some(&expected), request.json_type_hint.get("k0"));
        Ok(())
    }

    #[test]
    fn test_json_type_concretize_rule_conflict_to_variant() -> Result<()> {
        let exprs = vec![
            json_get_expr(col("k0"), path_expr("a"), Some(DataType::Int64))?.alias("a_num"),
            json_get_expr(col("k0"), path_expr("a.b"), Some(DataType::Boolean))?.alias("a_obj"),
        ];
        let (provider, plan) = build_plan(exprs)?;

        assert!(apply_rule(plan)?);

        let expected = JsonNativeType::Object(JsonObjectType::from([(
            "a".to_string(),
            JsonNativeType::Variant,
        )]));
        assert_eq!(
            Some(&expected),
            provider.scan_request().json_type_hint.get("k0")
        );
        Ok(())
    }

    #[test]
    fn test_json_type_concretize_rule_no_json_get() -> Result<()> {
        let (provider, plan) = build_plan(vec![col("k0"), col("v0")])?;

        assert!(!apply_rule(plan)?);
        assert!(provider.scan_request().json_type_hint.is_empty());
        Ok(())
    }

    #[test]
    fn test_deduce_json_type_with_non_column_base() -> Result<()> {
        let expr = json_get_expr(
            Expr::Literal(ScalarValue::Utf8(Some("{}".to_string())), None),
            path_expr("a"),
            Some(DataType::Int64),
        )?;

        let err = deduce_json_type(&expr).unwrap_err();
        assert!(
            err.to_string()
                .contains("First argument of json_get is expected to be a column expr")
        );
        Ok(())
    }

    #[test]
    fn test_deduce_json_type_with_non_literal_path() -> Result<()> {
        let expr = json_get_expr(
            column_k0(),
            Expr::Column(Column::new_unqualified("path_col")),
            Some(DataType::Int64),
        )?;

        let err = deduce_json_type(&expr).unwrap_err();
        assert!(
            err.to_string()
                .contains("Second argument of json_get is expected to be a string literal")
        );
        Ok(())
    }

    #[test]
    fn test_deduce_json_type_default_string() -> Result<()> {
        let expr = json_get_expr(column_k0(), path_expr("a.b"), None)?;

        let deduced = deduce_json_type(&expr)?;
        let expected = JsonNativeType::Object(JsonObjectType::from([(
            "a".to_string(),
            JsonNativeType::Object(JsonObjectType::from([(
                "b".to_string(),
                JsonNativeType::String,
            )])),
        )]));

        assert_eq!(Some(("k0".to_string(), expected)), deduced);
        Ok(())
    }

    #[test]
    fn test_deduce_json_type_nested_path() -> Result<()> {
        let expr = json_get_expr(column_k0(), path_expr("a.b.c"), Some(DataType::Boolean))?;

        let expected = JsonNativeType::Object(JsonObjectType::from([(
            "a".to_string(),
            JsonNativeType::Object(JsonObjectType::from([(
                "b".to_string(),
                JsonNativeType::Object(JsonObjectType::from([(
                    "c".to_string(),
                    JsonNativeType::Bool,
                )])),
            )])),
        )]));

        assert_eq!(Some(("k0".to_string(), expected)), deduce_json_type(&expr)?);
        Ok(())
    }

    #[test]
    fn test_deduce_json_type_ignores_other_exprs() -> Result<()> {
        assert_eq!(None, deduce_json_type(&col("k0"))?);
        assert_eq!(None, deduce_json_type(&path_expr("a.b"))?);
        Ok(())
    }
}