query/query_engine/state.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, RwLock};

use async_trait::async_trait;
use catalog::CatalogManagerRef;
use common_base::Plugins;
use common_function::function_factory::ScalarFunctionFactory;
use common_function::handlers::{
    FlowServiceHandlerRef, ProcedureServiceHandlerRef, TableMutationHandlerRef,
};
use common_function::state::FunctionState;
use common_telemetry::warn;
use datafusion::catalog::TableFunction;
use datafusion::dataframe::DataFrame;
use datafusion::error::Result as DfResult;
use datafusion::execution::SessionStateBuilder;
use datafusion::execution::context::{QueryPlanner, SessionConfig, SessionContext, SessionState};
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_optimizer::optimizer::PhysicalOptimizer;
use datafusion::physical_optimizer::sanity_checker::SanityCheckPlan;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner};
use datafusion_expr::{AggregateUDF, LogicalPlan as DfLogicalPlan};
use datafusion_optimizer::analyzer::Analyzer;
use datafusion_optimizer::optimizer::Optimizer;
use partition::manager::PartitionRuleManagerRef;
use promql::extension_plan::PromExtensionPlanner;
use table::TableRef;
use table::table::adapter::DfTableProviderAdapter;

use crate::QueryEngineContext;
use crate::dist_plan::{
    DistExtensionPlanner, DistPlannerAnalyzer, DistPlannerOptions, MergeSortExtensionPlanner,
};
use crate::optimizer::ExtensionAnalyzerRule;
use crate::optimizer::constant_term::MatchesConstantTermOptimizer;
use crate::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
use crate::optimizer::parallelize_scan::ParallelizeScan;
use crate::optimizer::pass_distribution::PassDistribution;
use crate::optimizer::remove_duplicate::RemoveDuplicate;
use crate::optimizer::scan_hint::ScanHintRule;
use crate::optimizer::string_normalization::StringNormalizationRule;
use crate::optimizer::transcribe_atat::TranscribeAtatRule;
use crate::optimizer::type_conversion::TypeConversionRule;
use crate::optimizer::windowed_sort::WindowedSortPhysicalRule;
use crate::options::QueryOptions as QueryOptionsNew;
use crate::query_engine::DefaultSerializer;
use crate::query_engine::options::QueryOptions;
use crate::range_select::planner::RangeSelectPlanner;
use crate::region_query::RegionQueryHandlerRef;

/// Query engine global state
#[derive(Clone)]
pub struct QueryEngineState {
    df_context: SessionContext,
    catalog_manager: CatalogManagerRef,
    function_state: Arc<FunctionState>,
    scalar_functions: Arc<RwLock<HashMap<String, ScalarFunctionFactory>>>,
    aggr_functions: Arc<RwLock<HashMap<String, AggregateUDF>>>,
    table_functions: Arc<RwLock<HashMap<String, Arc<TableFunction>>>>,
    extension_rules: Vec<Arc<dyn ExtensionAnalyzerRule + Send + Sync>>,
    plugins: Plugins,
}

impl fmt::Debug for QueryEngineState {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("QueryEngineState")
            .field("state", &self.df_context.state())
            .finish()
    }
}

impl QueryEngineState {
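    /// Creates a new [`QueryEngineState`].
    ///
    /// A minimal construction sketch for a standalone setup (not compiled
    /// here; assumes a `CatalogManagerRef` is at hand and that `Plugins` and
    /// `QueryOptionsNew` provide the constructors shown):
    ///
    /// ```ignore
    /// let state = QueryEngineState::new(
    ///     catalog_manager, // CatalogManagerRef
    ///     None,            // no partition rule manager
    ///     None,            // no region query handler
    ///     None,            // no table mutation handler
    ///     None,            // no procedure service handler
    ///     None,            // no flow service handler
    ///     false,           // standalone: skip the distributed planner
    ///     Plugins::new(),
    ///     QueryOptionsNew::default(),
    /// );
    /// ```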
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        catalog_list: CatalogManagerRef,
        partition_rule_manager: Option<PartitionRuleManagerRef>,
        region_query_handler: Option<RegionQueryHandlerRef>,
        table_mutation_handler: Option<TableMutationHandlerRef>,
        procedure_service_handler: Option<ProcedureServiceHandlerRef>,
        flow_service_handler: Option<FlowServiceHandlerRef>,
        with_dist_planner: bool,
        plugins: Plugins,
        options: QueryOptionsNew,
    ) -> Self {
        let runtime_env = Arc::new(RuntimeEnv::default());
        let mut session_config = SessionConfig::new().with_create_default_catalog_and_schema(false);
        if options.parallelism > 0 {
            session_config = session_config.with_target_partitions(options.parallelism);
        }
        if options.allow_query_fallback {
            session_config
                .options_mut()
                .extensions
                .insert(DistPlannerOptions {
                    allow_query_fallback: true,
                });
        }

        // todo(hl): This serves as a workaround for https://github.com/GreptimeTeam/greptimedb/issues/5659
        // and we can add that check back once we upgrade datafusion.
        session_config
            .options_mut()
            .execution
            .skip_physical_aggregate_schema_check = true;

        // Set up the extension analyzer rules
        let mut extension_rules = Vec::new();

        // The [`TypeConversionRule`] must come first
        extension_rules.insert(0, Arc::new(TypeConversionRule) as _);

        // Prepend our analyzer rules to the DataFusion defaults
        let mut analyzer = Analyzer::new();
        analyzer.rules.insert(0, Arc::new(TranscribeAtatRule));
        analyzer.rules.insert(0, Arc::new(StringNormalizationRule));
        analyzer
            .rules
            .insert(0, Arc::new(CountWildcardToTimeIndexRule));

        if with_dist_planner {
            analyzer.rules.push(Arc::new(DistPlannerAnalyzer));
        }

        let mut optimizer = Optimizer::new();
        optimizer.rules.push(Arc::new(ScanHintRule));

        // Add physical optimizer rules
        let mut physical_optimizer = PhysicalOptimizer::new();
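        // Note: the fixed indices below (5, 6, 7) are positions within
        // DataFusion's default physical optimizer rule list, chosen relative
        // to EnforceDistribution/EnforceSorting as described in the comments
        // on each insertion; they may need adjusting when upgrading DataFusion.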
        // Change TableScan's partition right before enforcing distribution
        physical_optimizer
            .rules
            .insert(5, Arc::new(ParallelizeScan));
        // Pass distribution requirement to MergeScanExec to avoid unnecessary shuffling
        physical_optimizer
            .rules
            .insert(6, Arc::new(PassDistribution));
        // Enforce sorting AFTER custom rules that modify the plan structure
        physical_optimizer.rules.insert(
            7,
            Arc::new(datafusion::physical_optimizer::enforce_sorting::EnforceSorting {}),
        );
        // Add rule for windowed sort
        physical_optimizer
            .rules
            .push(Arc::new(WindowedSortPhysicalRule));
        physical_optimizer
            .rules
            .push(Arc::new(MatchesConstantTermOptimizer));
        // Add rule to remove duplicate nodes generated by other rules. Run this last.
        physical_optimizer.rules.push(Arc::new(RemoveDuplicate));
        // Place SanityCheckPlan at the end of the list to ensure that it runs after all other rules.
        Self::remove_physical_optimizer_rule(
            &mut physical_optimizer.rules,
            SanityCheckPlan {}.name(),
        );
        physical_optimizer.rules.push(Arc::new(SanityCheckPlan {}));

        let session_state = SessionStateBuilder::new()
            .with_config(session_config)
            .with_runtime_env(runtime_env)
            .with_default_features()
            .with_analyzer_rules(analyzer.rules)
            .with_serializer_registry(Arc::new(DefaultSerializer))
            .with_query_planner(Arc::new(DfQueryPlanner::new(
                catalog_list.clone(),
                partition_rule_manager,
                region_query_handler,
            )))
            .with_optimizer_rules(optimizer.rules)
            .with_physical_optimizer_rules(physical_optimizer.rules)
            .build();

        let df_context = SessionContext::new_with_state(session_state);

        Self {
            df_context,
            catalog_manager: catalog_list,
            function_state: Arc::new(FunctionState {
                table_mutation_handler,
                procedure_service_handler,
                flow_service_handler,
            }),
            aggr_functions: Arc::new(RwLock::new(HashMap::new())),
            table_functions: Arc::new(RwLock::new(HashMap::new())),
            extension_rules,
            plugins,
            scalar_functions: Arc::new(RwLock::new(HashMap::new())),
        }
    }

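    /// Removes every physical optimizer rule whose name matches `name`.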
    fn remove_physical_optimizer_rule(
        rules: &mut Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>>,
        name: &str,
    ) {
        rules.retain(|rule| rule.name() != name);
    }

    /// Optimize the logical plan by the extension analyzer rules.
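    ///
    /// Rules run in registration order, each receiving the plan produced by
    /// the previous rule. A usage sketch (not compiled here; `plan` and
    /// `context` are assumed to come from the caller):
    ///
    /// ```ignore
    /// let plan = state.optimize_by_extension_rules(plan, &context)?;
    /// ```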
    pub fn optimize_by_extension_rules(
        &self,
        plan: DfLogicalPlan,
        context: &QueryEngineContext,
    ) -> DfResult<DfLogicalPlan> {
        self.extension_rules
            .iter()
            .try_fold(plan, |acc_plan, rule| {
                rule.analyze(acc_plan, context, self.session_state().config_options())
            })
    }

    /// Run the full logical plan optimize phase for the given plan.
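    ///
    /// Delegates to DataFusion's `SessionState::optimize`, which runs the
    /// registered analyzer and optimizer rules. A usage sketch (not compiled
    /// here):
    ///
    /// ```ignore
    /// let optimized = state.optimize_logical_plan(plan)?;
    /// ```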
    pub fn optimize_logical_plan(&self, plan: DfLogicalPlan) -> DfResult<DfLogicalPlan> {
        self.session_state().optimize(&plan)
    }

    /// Retrieve the scalar function by name.
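    ///
    /// Returns `None` if no function with that name is registered. A usage
    /// sketch (not compiled here; the function name is illustrative):
    ///
    /// ```ignore
    /// if let Some(factory) = state.scalar_function("my_udf") {
    ///     // build the concrete ScalarUDF from the factory ...
    /// }
    /// ```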
    pub fn scalar_function(&self, function_name: &str) -> Option<ScalarFunctionFactory> {
        self.scalar_functions
            .read()
            .unwrap()
            .get(function_name)
            .cloned()
    }

    /// Retrieve scalar function names.
    pub fn scalar_names(&self) -> Vec<String> {
        self.scalar_functions
            .read()
            .unwrap()
            .keys()
            .cloned()
            .collect()
    }

    /// Retrieve the aggregate function by name.
    pub fn aggr_function(&self, function_name: &str) -> Option<AggregateUDF> {
        self.aggr_functions
            .read()
            .unwrap()
            .get(function_name)
            .cloned()
    }

    /// Retrieve aggregate function names.
    pub fn aggr_names(&self) -> Vec<String> {
        self.aggr_functions
            .read()
            .unwrap()
            .keys()
            .cloned()
            .collect()
    }

    /// Retrieve the table function by name.
    pub fn table_function(&self, function_name: &str) -> Option<Arc<TableFunction>> {
        self.table_functions
            .read()
            .unwrap()
            .get(function_name)
            .cloned()
    }

    /// Retrieve table function names.
    pub fn table_function_names(&self) -> Vec<String> {
        self.table_functions
            .read()
            .unwrap()
            .keys()
            .cloned()
            .collect()
    }

    /// Register a scalar function.
    /// Overrides and warns if a function with the same name is already registered.
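    ///
    /// A usage sketch (not compiled here; `my_factory` and `my_factory_v2`
    /// are illustrative `ScalarFunctionFactory` values):
    ///
    /// ```ignore
    /// state.register_scalar_function(my_factory);
    /// // Registering under the same name replaces the first and logs a warning.
    /// state.register_scalar_function(my_factory_v2);
    /// ```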
    pub fn register_scalar_function(&self, func: ScalarFunctionFactory) {
        let name = func.name().to_string();
        let x = self
            .scalar_functions
            .write()
            .unwrap()
            .insert(name.clone(), func);

        if x.is_some() {
            warn!("Already registered scalar function '{name}'");
        }
    }

    /// Register an aggregate function.
    ///
    /// # Panics
    /// Will panic if a function with the same name is already registered.
    ///
    /// Panicking consideration: currently all aggregate functions are statically
    /// registered; users cannot define their own aggregate functions on the fly,
    /// so panicking here is acceptable. If that invariant is broken in the future,
    /// we should return an error instead of panicking.
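    ///
    /// A usage sketch (not compiled here; `my_udaf` is an illustrative
    /// `AggregateUDF`):
    ///
    /// ```ignore
    /// state.register_aggr_function(my_udaf);
    /// // Registering the same name again would panic:
    /// // state.register_aggr_function(my_udaf_duplicate);
    /// ```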
    pub fn register_aggr_function(&self, func: AggregateUDF) {
        let name = func.name().to_string();
        let x = self
            .aggr_functions
            .write()
            .unwrap()
            .insert(name.clone(), func);
        assert!(
            x.is_none(),
            "Already registered aggregate function '{name}'"
        );
    }

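    /// Register a table function.
    /// Overrides and warns if a function with the same name is already registered.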
    pub fn register_table_function(&self, func: Arc<TableFunction>) {
        let name = func.name();
        let x = self
            .table_functions
            .write()
            .unwrap()
            .insert(name.to_string(), func.clone());

        if x.is_some() {
            warn!("Already registered table function '{name}'");
        }
    }

    pub fn catalog_manager(&self) -> &CatalogManagerRef {
        &self.catalog_manager
    }

    pub fn function_state(&self) -> Arc<FunctionState> {
        self.function_state.clone()
    }

    /// Returns the [`TableMutationHandlerRef`] in state.
    pub fn table_mutation_handler(&self) -> Option<&TableMutationHandlerRef> {
        self.function_state.table_mutation_handler.as_ref()
    }

    /// Returns the [`ProcedureServiceHandlerRef`] in state.
    pub fn procedure_service_handler(&self) -> Option<&ProcedureServiceHandlerRef> {
        self.function_state.procedure_service_handler.as_ref()
    }

    pub(crate) fn disallow_cross_catalog_query(&self) -> bool {
        self.plugins
            .map::<QueryOptions, _, _>(|x| x.disallow_cross_catalog_query)
            .unwrap_or(false)
    }

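    /// Returns a snapshot of the underlying DataFusion [`SessionState`].
    ///
    /// Note that `SessionContext::state` clones the state, so mutating the
    /// returned value does not affect this engine's context.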
    pub fn session_state(&self) -> SessionState {
        self.df_context.state()
    }

    /// Create a DataFrame for a table.
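    ///
    /// The table is wrapped in a [`DfTableProviderAdapter`] so DataFusion can
    /// scan it. A usage sketch (not compiled here; `table` is an illustrative
    /// `TableRef`):
    ///
    /// ```ignore
    /// let df = state.read_table(table)?;
    /// let batches = df.collect().await?;
    /// ```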
    pub fn read_table(&self, table: TableRef) -> DfResult<DataFrame> {
        self.df_context
            .read_table(Arc::new(DfTableProviderAdapter::new(table)))
    }
}

struct DfQueryPlanner {
    physical_planner: DefaultPhysicalPlanner,
}

impl fmt::Debug for DfQueryPlanner {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("DfQueryPlanner").finish()
    }
}

#[async_trait]
impl QueryPlanner for DfQueryPlanner {
    async fn create_physical_plan(
        &self,
        logical_plan: &DfLogicalPlan,
        session_state: &SessionState,
    ) -> DfResult<Arc<dyn ExecutionPlan>> {
        self.physical_planner
            .create_physical_plan(logical_plan, session_state)
            .await
    }
}

impl DfQueryPlanner {
    fn new(
        catalog_manager: CatalogManagerRef,
        partition_rule_manager: Option<PartitionRuleManagerRef>,
        region_query_handler: Option<RegionQueryHandlerRef>,
    ) -> Self {
        let mut planners: Vec<Arc<dyn ExtensionPlanner + Send + Sync>> =
            vec![Arc::new(PromExtensionPlanner), Arc::new(RangeSelectPlanner)];
        if let (Some(region_query_handler), Some(partition_rule_manager)) =
            (region_query_handler, partition_rule_manager)
        {
            planners.push(Arc::new(DistExtensionPlanner::new(
                catalog_manager,
                partition_rule_manager,
                region_query_handler,
            )));
            planners.push(Arc::new(MergeSortExtensionPlanner {}));
        }
        Self {
            physical_planner: DefaultPhysicalPlanner::with_extension_planners(planners),
        }
    }
}