query/query_engine/state.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, RwLock};

use async_trait::async_trait;
use catalog::CatalogManagerRef;
use common_base::Plugins;
use common_function::aggrs::aggr_wrapper::fix_order::FixStateUdafOrderingAnalyzer;
use common_function::function_factory::ScalarFunctionFactory;
use common_function::handlers::{
    FlowServiceHandlerRef, ProcedureServiceHandlerRef, TableMutationHandlerRef,
};
use common_function::state::FunctionState;
use common_telemetry::warn;
use datafusion::catalog::TableFunction;
use datafusion::dataframe::DataFrame;
use datafusion::error::Result as DfResult;
use datafusion::execution::SessionStateBuilder;
use datafusion::execution::context::{QueryPlanner, SessionConfig, SessionContext, SessionState};
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_optimizer::optimizer::PhysicalOptimizer;
use datafusion::physical_optimizer::sanity_checker::SanityCheckPlan;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner};
use datafusion_expr::{AggregateUDF, LogicalPlan as DfLogicalPlan};
use datafusion_optimizer::analyzer::Analyzer;
use datafusion_optimizer::optimizer::Optimizer;
use partition::manager::PartitionRuleManagerRef;
use promql::extension_plan::PromExtensionPlanner;
use table::TableRef;
use table::table::adapter::DfTableProviderAdapter;

use crate::QueryEngineContext;
use crate::dist_plan::{
    DistExtensionPlanner, DistPlannerAnalyzer, DistPlannerOptions, MergeSortExtensionPlanner,
};
use crate::optimizer::ExtensionAnalyzerRule;
use crate::optimizer::constant_term::MatchesConstantTermOptimizer;
use crate::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
use crate::optimizer::parallelize_scan::ParallelizeScan;
use crate::optimizer::pass_distribution::PassDistribution;
use crate::optimizer::remove_duplicate::RemoveDuplicate;
use crate::optimizer::scan_hint::ScanHintRule;
use crate::optimizer::string_normalization::StringNormalizationRule;
use crate::optimizer::transcribe_atat::TranscribeAtatRule;
use crate::optimizer::type_conversion::TypeConversionRule;
use crate::optimizer::windowed_sort::WindowedSortPhysicalRule;
use crate::options::QueryOptions as QueryOptionsNew;
use crate::query_engine::DefaultSerializer;
use crate::query_engine::options::QueryOptions;
use crate::range_select::planner::RangeSelectPlanner;
use crate::region_query::RegionQueryHandlerRef;

/// Query engine global state
#[derive(Clone)]
pub struct QueryEngineState {
    /// The underlying DataFusion session context.
    df_context: SessionContext,
    /// Catalog manager for resolving table references.
    catalog_manager: CatalogManagerRef,
    /// Shared state for functions that need table mutation, procedure or flow services.
    function_state: Arc<FunctionState>,
    /// Registered scalar functions, keyed by function name.
    scalar_functions: Arc<RwLock<HashMap<String, ScalarFunctionFactory>>>,
    /// Registered aggregate functions, keyed by function name.
    aggr_functions: Arc<RwLock<HashMap<String, AggregateUDF>>>,
    /// Registered table functions, keyed by function name.
    table_functions: Arc<RwLock<HashMap<String, Arc<TableFunction>>>>,
    /// Extension analyzer rules applied before the regular optimize phase.
    extension_rules: Vec<Arc<dyn ExtensionAnalyzerRule + Send + Sync>>,
    /// Plugins that may carry per-deployment query options.
    plugins: Plugins,
}

impl fmt::Debug for QueryEngineState {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("QueryEngineState")
            .field("state", &self.df_context.state())
            .finish()
    }
}

impl QueryEngineState {
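    /// Build a new [`QueryEngineState`].
    ///
    /// # Example
    ///
    /// A minimal sketch of the wiring for a standalone setup without the distributed
    /// planner or any of the optional handlers; it assumes `Plugins` and
    /// `QueryOptionsNew` provide `Default` implementations:
    ///
    /// ```ignore
    /// let state = QueryEngineState::new(
    ///     catalog_manager,            // CatalogManagerRef
    ///     None,                       // partition_rule_manager
    ///     None,                       // region_query_handler
    ///     None,                       // table_mutation_handler
    ///     None,                       // procedure_service_handler
    ///     None,                       // flow_service_handler
    ///     false,                      // with_dist_planner
    ///     Plugins::default(),
    ///     QueryOptionsNew::default(),
    /// );
    /// ```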
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        catalog_list: CatalogManagerRef,
        partition_rule_manager: Option<PartitionRuleManagerRef>,
        region_query_handler: Option<RegionQueryHandlerRef>,
        table_mutation_handler: Option<TableMutationHandlerRef>,
        procedure_service_handler: Option<ProcedureServiceHandlerRef>,
        flow_service_handler: Option<FlowServiceHandlerRef>,
        with_dist_planner: bool,
        plugins: Plugins,
        options: QueryOptionsNew,
    ) -> Self {
        let runtime_env = Arc::new(RuntimeEnv::default());
        let mut session_config = SessionConfig::new().with_create_default_catalog_and_schema(false);
        if options.parallelism > 0 {
            session_config = session_config.with_target_partitions(options.parallelism);
        }
        if options.allow_query_fallback {
            session_config
                .options_mut()
                .extensions
                .insert(DistPlannerOptions {
                    allow_query_fallback: true,
                });
        }

        // todo(hl): This serves as a workaround for https://github.com/GreptimeTeam/greptimedb/issues/5659
        // and we can add that check back once we upgrade datafusion.
        session_config
            .options_mut()
            .execution
            .skip_physical_aggregate_schema_check = true;

        // Build the extension analyzer rules.
        let mut extension_rules = Vec::new();

        // The [`TypeConversionRule`] must come first.
        extension_rules.insert(0, Arc::new(TypeConversionRule) as _);

        // Set up the DataFusion analyzer rules.
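        // Note: each `insert(0, ..)` below prepends, so the rules run in reverse
        // order of insertion (CountWildcardToTimeIndexRule first, then
        // StringNormalizationRule, then TranscribeAtatRule) ahead of DataFusion's
        // default analyzer rules.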
        let mut analyzer = Analyzer::new();
        analyzer.rules.insert(0, Arc::new(TranscribeAtatRule));
        analyzer.rules.insert(0, Arc::new(StringNormalizationRule));
        analyzer
            .rules
            .insert(0, Arc::new(CountWildcardToTimeIndexRule));

        if with_dist_planner {
            analyzer.rules.push(Arc::new(DistPlannerAnalyzer));
        }

        analyzer.rules.push(Arc::new(FixStateUdafOrderingAnalyzer));

        let mut optimizer = Optimizer::new();
        optimizer.rules.push(Arc::new(ScanHintRule));

        // Add the physical optimizer rules.
        let mut physical_optimizer = PhysicalOptimizer::new();
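        // The fixed indices used below splice our rules into DataFusion's default
        // physical optimizer pass list; they assume the default rule order and need
        // revisiting whenever DataFusion is upgraded.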
        // Change TableScan's partition right before enforcing distribution
        physical_optimizer
            .rules
            .insert(5, Arc::new(ParallelizeScan));
        // Pass distribution requirement to MergeScanExec to avoid unnecessary shuffling
        physical_optimizer
            .rules
            .insert(6, Arc::new(PassDistribution));
        // Enforce sorting AFTER custom rules that modify the plan structure
        physical_optimizer.rules.insert(
            7,
            Arc::new(datafusion::physical_optimizer::enforce_sorting::EnforceSorting {}),
        );
        // Add rule for windowed sort
        physical_optimizer
            .rules
            .push(Arc::new(WindowedSortPhysicalRule));
        physical_optimizer
            .rules
            .push(Arc::new(MatchesConstantTermOptimizer));
        // Add a rule to remove duplicate nodes generated by other rules. Run it last.
        physical_optimizer.rules.push(Arc::new(RemoveDuplicate));
        // Place SanityCheckPlan at the end of the list to ensure that it runs after all other rules.
        Self::remove_physical_optimizer_rule(
            &mut physical_optimizer.rules,
            SanityCheckPlan {}.name(),
        );
        physical_optimizer.rules.push(Arc::new(SanityCheckPlan {}));

        let session_state = SessionStateBuilder::new()
            .with_config(session_config)
            .with_runtime_env(runtime_env)
            .with_default_features()
            .with_analyzer_rules(analyzer.rules)
            .with_serializer_registry(Arc::new(DefaultSerializer))
            .with_query_planner(Arc::new(DfQueryPlanner::new(
                catalog_list.clone(),
                partition_rule_manager,
                region_query_handler,
            )))
            .with_optimizer_rules(optimizer.rules)
            .with_physical_optimizer_rules(physical_optimizer.rules)
            .build();

        let df_context = SessionContext::new_with_state(session_state);

        Self {
            df_context,
            catalog_manager: catalog_list,
            function_state: Arc::new(FunctionState {
                table_mutation_handler,
                procedure_service_handler,
                flow_service_handler,
            }),
            aggr_functions: Arc::new(RwLock::new(HashMap::new())),
            table_functions: Arc::new(RwLock::new(HashMap::new())),
            extension_rules,
            plugins,
            scalar_functions: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Remove all physical optimizer rules with the given name.
    fn remove_physical_optimizer_rule(
        rules: &mut Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>>,
        name: &str,
    ) {
        rules.retain(|rule| rule.name() != name);
    }

    /// Optimize the logical plan by the extension analyzer rules.
    pub fn optimize_by_extension_rules(
        &self,
        plan: DfLogicalPlan,
        context: &QueryEngineContext,
    ) -> DfResult<DfLogicalPlan> {
        self.extension_rules
            .iter()
            .try_fold(plan, |acc_plan, rule| {
                rule.analyze(acc_plan, context, self.session_state().config_options())
            })
    }

    /// Run the full logical plan optimize phase for the given plan.
    pub fn optimize_logical_plan(&self, plan: DfLogicalPlan) -> DfResult<DfLogicalPlan> {
        self.session_state().optimize(&plan)
    }

    /// Retrieve the scalar function by name
    pub fn scalar_function(&self, function_name: &str) -> Option<ScalarFunctionFactory> {
        self.scalar_functions
            .read()
            .unwrap()
            .get(function_name)
            .cloned()
    }

    /// Retrieve scalar function names.
    pub fn scalar_names(&self) -> Vec<String> {
        self.scalar_functions
            .read()
            .unwrap()
            .keys()
            .cloned()
            .collect()
    }

    /// Retrieve the aggregate function by name
    pub fn aggr_function(&self, function_name: &str) -> Option<AggregateUDF> {
        self.aggr_functions
            .read()
            .unwrap()
            .get(function_name)
            .cloned()
    }

    /// Retrieve aggregate function names.
    pub fn aggr_names(&self) -> Vec<String> {
        self.aggr_functions
            .read()
            .unwrap()
            .keys()
            .cloned()
            .collect()
    }

    /// Retrieve table function by name
    pub fn table_function(&self, function_name: &str) -> Option<Arc<TableFunction>> {
        self.table_functions
            .read()
            .unwrap()
            .get(function_name)
            .cloned()
    }

    /// Retrieve table function names.
    pub fn table_function_names(&self) -> Vec<String> {
        self.table_functions
            .read()
            .unwrap()
            .keys()
            .cloned()
            .collect()
    }

    /// Register a scalar function.
    /// Overrides the old function if one with the same name is already registered.
    pub fn register_scalar_function(&self, func: ScalarFunctionFactory) {
        let name = func.name().to_string();
        let x = self
            .scalar_functions
            .write()
            .unwrap()
            .insert(name.clone(), func);

        if x.is_some() {
            warn!("Already registered scalar function '{name}'");
        }
    }

    /// Register an aggregate function.
    ///
    /// # Panics
    /// Panics if a function with the same name is already registered.
    ///
    /// Panicking consideration: currently all aggregate functions are registered
    /// statically; users cannot define their own aggregate functions on the fly, so
    /// panicking here is acceptable. If that invariant is broken in the future, this
    /// should return an error instead of panicking.
    pub fn register_aggr_function(&self, func: AggregateUDF) {
        let name = func.name().to_string();
        let x = self
            .aggr_functions
            .write()
            .unwrap()
            .insert(name.clone(), func);
        assert!(
            x.is_none(),
            "Already registered aggregate function '{name}'"
        );
    }

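    /// Register a table function.
    /// Overrides the old function and logs a warning if one with the same name is
    /// already registered.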
    pub fn register_table_function(&self, func: Arc<TableFunction>) {
        let name = func.name();
        let x = self
            .table_functions
            .write()
            .unwrap()
            .insert(name.to_string(), func.clone());

        if x.is_some() {
            warn!("Already registered table function '{name}'");
        }
    }

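    /// Returns the [`CatalogManagerRef`] in state.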
    pub fn catalog_manager(&self) -> &CatalogManagerRef {
        &self.catalog_manager
    }

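    /// Returns the shared [`FunctionState`] in state.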
    pub fn function_state(&self) -> Arc<FunctionState> {
        self.function_state.clone()
    }

    /// Returns the [`TableMutationHandlerRef`] in state.
    pub fn table_mutation_handler(&self) -> Option<&TableMutationHandlerRef> {
        self.function_state.table_mutation_handler.as_ref()
    }

    /// Returns the [`ProcedureServiceHandlerRef`] in state.
    pub fn procedure_service_handler(&self) -> Option<&ProcedureServiceHandlerRef> {
        self.function_state.procedure_service_handler.as_ref()
    }

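    /// Returns whether cross-catalog queries are disallowed, as configured by the
    /// [`QueryOptions`] plugin; defaults to `false` when the plugin is absent.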
    pub(crate) fn disallow_cross_catalog_query(&self) -> bool {
        self.plugins
            .map::<QueryOptions, _, _>(|x| x.disallow_cross_catalog_query)
            .unwrap_or(false)
    }

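    /// Returns a snapshot of the underlying DataFusion [`SessionState`].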
    pub fn session_state(&self) -> SessionState {
        self.df_context.state()
    }

    /// Create a DataFrame for a table.
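    ///
    /// # Example
    ///
    /// A sketch, assuming `table` is a [`TableRef`] already resolved from the catalog:
    ///
    /// ```ignore
    /// let df = state.read_table(table)?;
    /// let batches = df.collect().await?;
    /// ```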
    pub fn read_table(&self, table: TableRef) -> DfResult<DataFrame> {
        self.df_context
            .read_table(Arc::new(DfTableProviderAdapter::new(table)))
    }
}

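/// Query planner that delegates to DataFusion's [`DefaultPhysicalPlanner`], extended
/// with GreptimeDB-specific extension planners (PromQL, range select and, when
/// available, distributed planning).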
struct DfQueryPlanner {
    physical_planner: DefaultPhysicalPlanner,
}

impl fmt::Debug for DfQueryPlanner {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("DfQueryPlanner").finish()
    }
}

#[async_trait]
impl QueryPlanner for DfQueryPlanner {
    async fn create_physical_plan(
        &self,
        logical_plan: &DfLogicalPlan,
        session_state: &SessionState,
    ) -> DfResult<Arc<dyn ExecutionPlan>> {
        self.physical_planner
            .create_physical_plan(logical_plan, session_state)
            .await
    }
}

impl DfQueryPlanner {
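    /// Build the planner with the GreptimeDB extension planners. The distributed
    /// planners are only installed when both a region query handler and a partition
    /// rule manager are available.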
    fn new(
        catalog_manager: CatalogManagerRef,
        partition_rule_manager: Option<PartitionRuleManagerRef>,
        region_query_handler: Option<RegionQueryHandlerRef>,
    ) -> Self {
        let mut planners: Vec<Arc<dyn ExtensionPlanner + Send + Sync>> =
            vec![Arc::new(PromExtensionPlanner), Arc::new(RangeSelectPlanner)];
        if let (Some(region_query_handler), Some(partition_rule_manager)) =
            (region_query_handler, partition_rule_manager)
        {
            planners.push(Arc::new(DistExtensionPlanner::new(
                catalog_manager,
                partition_rule_manager,
                region_query_handler,
            )));
            planners.push(Arc::new(MergeSortExtensionPlanner {}));
        }
        Self {
            physical_planner: DefaultPhysicalPlanner::with_extension_planners(planners),
        }
    }
}