Skip to main content

query/query_engine/
state.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt;
17use std::num::NonZeroUsize;
18use std::sync::{Arc, RwLock};
19
20use async_trait::async_trait;
21use catalog::CatalogManagerRef;
22use common_base::Plugins;
23use common_function::aggrs::aggr_wrapper::fix_order::FixStateUdafOrderingAnalyzer;
24use common_function::function_factory::ScalarFunctionFactory;
25use common_function::function_registry::FUNCTION_REGISTRY;
26use common_function::handlers::{
27    FlowServiceHandlerRef, ProcedureServiceHandlerRef, TableMutationHandlerRef,
28};
29use common_function::state::FunctionState;
30use common_stat::get_total_memory_bytes;
31use common_telemetry::warn;
32use datafusion::catalog::TableFunction;
33use datafusion::dataframe::DataFrame;
34use datafusion::error::Result as DfResult;
35use datafusion::execution::SessionStateBuilder;
36use datafusion::execution::context::{QueryPlanner, SessionConfig, SessionContext, SessionState};
37use datafusion::execution::memory_pool::{
38    GreedyMemoryPool, MemoryConsumer, MemoryLimit, MemoryPool, MemoryReservation,
39    TrackConsumersPool,
40};
41use datafusion::execution::runtime_env::{RuntimeEnv, RuntimeEnvBuilder};
42use datafusion::physical_optimizer::PhysicalOptimizerRule;
43use datafusion::physical_optimizer::optimizer::PhysicalOptimizer;
44use datafusion::physical_optimizer::sanity_checker::SanityCheckPlan;
45use datafusion::physical_plan::ExecutionPlan;
46use datafusion::physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner};
47use datafusion_expr::{AggregateUDF, LogicalPlan as DfLogicalPlan, WindowUDF};
48use datafusion_optimizer::Analyzer;
49use datafusion_optimizer::analyzer::function_rewrite::ApplyFunctionRewrites;
50use datafusion_optimizer::optimizer::Optimizer;
51use partition::manager::PartitionRuleManagerRef;
52use promql::extension_plan::PromExtensionPlanner;
53use table::TableRef;
54use table::table::adapter::DfTableProviderAdapter;
55
56use crate::QueryEngineContext;
57use crate::dist_plan::{
58    DistExtensionPlanner, DistPlannerAnalyzer, DistPlannerOptions, MergeSortExtensionPlanner,
59};
60use crate::metrics::{QUERY_MEMORY_POOL_REJECTED_TOTAL, QUERY_MEMORY_POOL_USAGE_BYTES};
61use crate::optimizer::ExtensionAnalyzerRule;
62use crate::optimizer::constant_term::MatchesConstantTermOptimizer;
63use crate::optimizer::count_nest_aggr::CountNestAggrRule;
64use crate::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
65use crate::optimizer::parallelize_scan::ParallelizeScan;
66use crate::optimizer::pass_distribution::PassDistribution;
67use crate::optimizer::remove_duplicate::RemoveDuplicate;
68use crate::optimizer::scan_hint::ScanHintRule;
69use crate::optimizer::string_normalization::StringNormalizationRule;
70use crate::optimizer::transcribe_atat::TranscribeAtatRule;
71use crate::optimizer::type_conversion::TypeConversionRule;
72use crate::optimizer::windowed_sort::WindowedSortPhysicalRule;
73use crate::options::QueryOptions as QueryOptionsNew;
74use crate::query_engine::DefaultSerializer;
75use crate::query_engine::options::QueryOptions;
76use crate::range_select::planner::RangeSelectPlanner;
77use crate::region_query::RegionQueryHandlerRef;
78
/// Query engine global state
///
/// Cloning is cheap: all heavyweight members are behind `Arc`s or are
/// themselves cheaply clonable handles.
#[derive(Clone)]
pub struct QueryEngineState {
    /// The wrapped DataFusion session context, carrying the session config,
    /// runtime env, analyzer/optimizer rules and query planner assembled in
    /// [`QueryEngineState::new`].
    df_context: SessionContext,
    /// Catalog manager used to resolve catalogs/schemas/tables.
    catalog_manager: CatalogManagerRef,
    /// Shared handlers (table mutation / procedure / flow service) exposed to
    /// functions via [`FunctionState`].
    function_state: Arc<FunctionState>,
    /// Scalar functions registered at runtime, keyed by function name.
    scalar_functions: Arc<RwLock<HashMap<String, ScalarFunctionFactory>>>,
    /// Aggregate functions registered at runtime, keyed by function name.
    aggr_functions: Arc<RwLock<HashMap<String, AggregateUDF>>>,
    /// Table functions registered at runtime, keyed by function name.
    table_functions: Arc<RwLock<HashMap<String, Arc<TableFunction>>>>,
    /// Extension analyzer rules applied by
    /// [`QueryEngineState::optimize_by_extension_rules`] before the regular
    /// DataFusion optimization phase.
    extension_rules: Vec<Arc<dyn ExtensionAnalyzerRule + Send + Sync>>,
    /// Plugins; e.g. may carry [`QueryOptions`] consulted by
    /// [`QueryEngineState::disallow_cross_catalog_query`].
    plugins: Plugins,
}
91
92impl fmt::Debug for QueryEngineState {
93    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
94        f.debug_struct("QueryEngineState")
95            .field("state", &self.df_context.state())
96            .finish()
97    }
98}
99
impl QueryEngineState {
    /// Builds the engine state: constructs the DataFusion runtime environment
    /// (optionally with a metered memory pool), the session config, the
    /// analyzer/optimizer/physical-optimizer rule sets and the query planner,
    /// then wraps everything in a [`SessionContext`].
    ///
    /// `with_dist_planner` enables the distributed planner analyzer rule;
    /// distributed physical planning additionally requires both
    /// `partition_rule_manager` and `region_query_handler` to be `Some`.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        catalog_list: CatalogManagerRef,
        partition_rule_manager: Option<PartitionRuleManagerRef>,
        region_query_handler: Option<RegionQueryHandlerRef>,
        table_mutation_handler: Option<TableMutationHandlerRef>,
        procedure_service_handler: Option<ProcedureServiceHandlerRef>,
        flow_service_handler: Option<FlowServiceHandlerRef>,
        with_dist_planner: bool,
        plugins: Plugins,
        options: QueryOptionsNew,
    ) -> Self {
        // Resolve the configured pool size against total system memory.
        // A resolved size of 0 means "no limit": use the default runtime
        // (and skip the metrics-recording pool entirely).
        let total_memory = get_total_memory_bytes().max(0) as u64;
        let memory_pool_size = options.memory_pool_size.resolve(total_memory) as usize;
        let runtime_env = if memory_pool_size > 0 {
            Arc::new(
                RuntimeEnvBuilder::new()
                    .with_memory_pool(Arc::new(MetricsMemoryPool::new(memory_pool_size)))
                    .build()
                    .expect("Failed to build RuntimeEnv"),
            )
        } else {
            Arc::new(RuntimeEnv::default())
        };
        let mut session_config = SessionConfig::new().with_create_default_catalog_and_schema(false);
        if options.parallelism > 0 {
            session_config = session_config.with_target_partitions(options.parallelism);
        }
        if options.allow_query_fallback {
            // Stored as a config extension so the dist planner can read it later.
            session_config
                .options_mut()
                .extensions
                .insert(DistPlannerOptions {
                    allow_query_fallback: true,
                });
        }

        // todo(hl): This serves as a workaround for https://github.com/GreptimeTeam/greptimedb/issues/5659
        // and we can add that check back once we upgrade datafusion.
        session_config
            .options_mut()
            .execution
            .skip_physical_aggregate_schema_check = true;

        // Apply extension rules (run by `optimize_by_extension_rules` before
        // the regular DataFusion optimization phase).
        let mut extension_rules = Vec::new();

        // The [`TypeConversionRule`] must be at first
        extension_rules.insert(0, Arc::new(TypeConversionRule) as _);
        extension_rules.push(Arc::new(CountNestAggrRule) as _);

        // Apply the datafusion rules. Custom rules are inserted at the front
        // so they run before DataFusion's built-in analyzer rules.
        let mut analyzer = Analyzer::new();
        analyzer.rules.insert(0, Arc::new(TranscribeAtatRule));
        analyzer.rules.insert(0, Arc::new(StringNormalizationRule));
        analyzer
            .rules
            .insert(0, Arc::new(CountWildcardToTimeIndexRule));

        // Add ApplyFunctionRewrites rule,
        // Note we cannot use `analyzer.add_function_rewrite`
        // because only rules are copied into session_state
        analyzer.rules.insert(
            0,
            Arc::new(ApplyFunctionRewrites::new(
                FUNCTION_REGISTRY.function_rewrites(),
            )),
        );

        if with_dist_planner {
            analyzer.rules.push(Arc::new(DistPlannerAnalyzer));
        }
        analyzer.rules.push(Arc::new(FixStateUdafOrderingAnalyzer));

        let mut optimizer = Optimizer::new();
        optimizer.rules.push(Arc::new(ScanHintRule));

        // add physical optimizer
        let mut physical_optimizer = PhysicalOptimizer::new();
        // Change TableScan's partition right before enforcing distribution
        // (the insert positions are relative to DataFusion's default rule list).
        physical_optimizer
            .rules
            .insert(5, Arc::new(ParallelizeScan));
        // Pass distribution requirement to MergeScanExec to avoid unnecessary shuffling
        physical_optimizer
            .rules
            .insert(6, Arc::new(PassDistribution));
        // Enforce sorting AFTER custom rules that modify the plan structure
        physical_optimizer.rules.insert(
            7,
            Arc::new(datafusion::physical_optimizer::enforce_sorting::EnforceSorting {}),
        );
        // Add rule for windowed sort
        physical_optimizer
            .rules
            .push(Arc::new(WindowedSortPhysicalRule));
        // explicitly not do filter pushdown for windowed sort&part sort
        // (notice that `PartSortExec` create another new dyn filter that need to be pushdown if want to use dyn filter optimization)
        // benchmark shows it can cause performance regression due to useless filtering and extra shuffle.
        // We can add a rule to do filter pushdown for windowed sort in the future if we find a way to avoid the performance regression.
        physical_optimizer
            .rules
            .push(Arc::new(MatchesConstantTermOptimizer));
        // Add rule to remove duplicate nodes generated by other rules. Run this in the last.
        physical_optimizer.rules.push(Arc::new(RemoveDuplicate));
        // Place SanityCheckPlan at the end of the list to ensure that it runs after all other rules.
        // (Remove the built-in instance first so it is not run twice.)
        Self::remove_physical_optimizer_rule(
            &mut physical_optimizer.rules,
            SanityCheckPlan {}.name(),
        );
        physical_optimizer.rules.push(Arc::new(SanityCheckPlan {}));

        let session_state = SessionStateBuilder::new()
            .with_config(session_config)
            .with_runtime_env(runtime_env)
            .with_default_features()
            .with_analyzer_rules(analyzer.rules)
            .with_serializer_registry(Arc::new(DefaultSerializer))
            .with_query_planner(Arc::new(DfQueryPlanner::new(
                catalog_list.clone(),
                partition_rule_manager,
                region_query_handler,
            )))
            .with_optimizer_rules(optimizer.rules)
            .with_physical_optimizer_rules(physical_optimizer.rules)
            .build();

        let df_context = SessionContext::new_with_state(session_state);
        // Add MySQL-compatible function aliases (ucase, lcase, ...).
        register_function_aliases(&df_context);

        Self {
            df_context,
            catalog_manager: catalog_list,
            function_state: Arc::new(FunctionState {
                table_mutation_handler,
                procedure_service_handler,
                flow_service_handler,
            }),
            // Runtime-registered UDF maps start empty; functions are added
            // later via the `register_*` methods below.
            aggr_functions: Arc::new(RwLock::new(HashMap::new())),
            table_functions: Arc::new(RwLock::new(HashMap::new())),
            extension_rules,
            plugins,
            scalar_functions: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Removes every physical optimizer rule whose `name()` equals `name`.
    /// Used to reposition built-in rules (e.g. [`SanityCheckPlan`]).
    fn remove_physical_optimizer_rule(
        rules: &mut Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>>,
        name: &str,
    ) {
        rules.retain(|rule| rule.name() != name);
    }

    /// Optimize the logical plan by the extension analyzer rules.
    ///
    /// Rules are applied in registration order; the first failing rule aborts
    /// the fold and its error is returned.
    pub fn optimize_by_extension_rules(
        &self,
        plan: DfLogicalPlan,
        context: &QueryEngineContext,
    ) -> DfResult<DfLogicalPlan> {
        self.extension_rules
            .iter()
            .try_fold(plan, |acc_plan, rule| {
                rule.analyze(acc_plan, context, self.session_state().config_options())
            })
    }

    /// Run the full logical plan optimize phase for the given plan.
    pub fn optimize_logical_plan(&self, plan: DfLogicalPlan) -> DfResult<DfLogicalPlan> {
        self.session_state().optimize(&plan)
    }

    /// Retrieve the scalar function by name
    pub fn scalar_function(&self, function_name: &str) -> Option<ScalarFunctionFactory> {
        self.scalar_functions
            .read()
            .unwrap()
            .get(function_name)
            .cloned()
    }

    /// Retrieve scalar function names.
    pub fn scalar_names(&self) -> Vec<String> {
        self.scalar_functions
            .read()
            .unwrap()
            .keys()
            .cloned()
            .collect()
    }

    /// Retrieve the aggregate function by name
    pub fn aggr_function(&self, function_name: &str) -> Option<AggregateUDF> {
        self.aggr_functions
            .read()
            .unwrap()
            .get(function_name)
            .cloned()
    }

    /// Retrieve aggregate function names.
    pub fn aggr_names(&self) -> Vec<String> {
        self.aggr_functions
            .read()
            .unwrap()
            .keys()
            .cloned()
            .collect()
    }

    /// Retrieve table function by name
    pub fn table_function(&self, function_name: &str) -> Option<Arc<TableFunction>> {
        self.table_functions
            .read()
            .unwrap()
            .get(function_name)
            .cloned()
    }

    /// Retrieve table function names.
    pub fn table_function_names(&self) -> Vec<String> {
        self.table_functions
            .read()
            .unwrap()
            .keys()
            .cloned()
            .collect()
    }

    /// Register an scalar function.
    /// Will override if the function with same name is already registered.
    pub fn register_scalar_function(&self, func: ScalarFunctionFactory) {
        let name = func.name().to_string();
        let x = self
            .scalar_functions
            .write()
            .unwrap()
            .insert(name.clone(), func);

        // Overriding is allowed (unlike aggregates) but worth surfacing.
        if x.is_some() {
            warn!("Already registered scalar function '{name}'");
        }
    }

    /// Register an aggregate function.
    ///
    /// # Panics
    /// Will panic if the function with same name is already registered.
    ///
    /// Panicking consideration: currently the aggregated functions are all statically registered,
    /// user cannot define their own aggregate functions on the fly. So we can panic here. If that
    /// invariant is broken in the future, we should return an error instead of panicking.
    pub fn register_aggr_function(&self, func: AggregateUDF) {
        let name = func.name().to_string();
        let x = self
            .aggr_functions
            .write()
            .unwrap()
            .insert(name.clone(), func);
        assert!(
            x.is_none(),
            "Already registered aggregate function '{name}'"
        );
    }

    /// Register a table function.
    /// Will override (with a warning) if a function with the same name is
    /// already registered.
    pub fn register_table_function(&self, func: Arc<TableFunction>) {
        let name = func.name();
        let x = self
            .table_functions
            .write()
            .unwrap()
            .insert(name.to_string(), func.clone());

        if x.is_some() {
            warn!("Already registered table function '{name}'");
        }
    }

    /// Register a window function (UDWF) directly on the DataFusion SessionContext.
    ///
    /// This makes the function visible via `session_state.window_functions()`,
    /// which is used by `DfContextProviderAdapter::get_window_meta`.
    pub fn register_window_function(&self, func: WindowUDF) {
        self.df_context.register_udwf(func);
    }

    /// Returns the catalog manager held by this state.
    pub fn catalog_manager(&self) -> &CatalogManagerRef {
        &self.catalog_manager
    }

    /// Returns a shared handle to the [`FunctionState`] (mutation/procedure/flow handlers).
    pub fn function_state(&self) -> Arc<FunctionState> {
        self.function_state.clone()
    }

    /// Returns the [`TableMutationHandlerRef`] in state.
    pub fn table_mutation_handler(&self) -> Option<&TableMutationHandlerRef> {
        self.function_state.table_mutation_handler.as_ref()
    }

    /// Returns the [`ProcedureServiceHandlerRef`] in state.
    pub fn procedure_service_handler(&self) -> Option<&ProcedureServiceHandlerRef> {
        self.function_state.procedure_service_handler.as_ref()
    }

    /// Whether cross-catalog queries are disallowed, read from the
    /// [`QueryOptions`] plugin; defaults to `false` when the plugin is absent.
    pub(crate) fn disallow_cross_catalog_query(&self) -> bool {
        self.plugins
            .map::<QueryOptions, _, _>(|x| x.disallow_cross_catalog_query)
            .unwrap_or(false)
    }

    /// Returns the current DataFusion [`SessionState`] of the underlying context.
    pub fn session_state(&self) -> SessionState {
        self.df_context.state()
    }

    /// Create a DataFrame for a table
    pub fn read_table(&self, table: TableRef) -> DfResult<DataFrame> {
        self.df_context
            .read_table(Arc::new(DfTableProviderAdapter::new(table)))
    }
}
420
/// Query planner delegating to a [`DefaultPhysicalPlanner`] configured with
/// GreptimeDB's extension planners (PromQL, range select, and — when the
/// required handlers exist — distributed and merge-sort planning).
struct DfQueryPlanner {
    physical_planner: DefaultPhysicalPlanner,
}
424
425impl fmt::Debug for DfQueryPlanner {
426    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
427        f.debug_struct("DfQueryPlanner").finish()
428    }
429}
430
431#[async_trait]
432impl QueryPlanner for DfQueryPlanner {
433    async fn create_physical_plan(
434        &self,
435        logical_plan: &DfLogicalPlan,
436        session_state: &SessionState,
437    ) -> DfResult<Arc<dyn ExecutionPlan>> {
438        self.physical_planner
439            .create_physical_plan(logical_plan, session_state)
440            .await
441    }
442}
443
/// MySQL-compatible scalar function aliases: (target_name, alias)
///
/// Each pair is registered by [`register_function_aliases`] so the MySQL
/// spelling (right) resolves to the existing function (left).
const SCALAR_FUNCTION_ALIASES: &[(&str, &str)] = &[
    ("upper", "ucase"),
    ("lower", "lcase"),
    ("ceil", "ceiling"),
    ("substr", "mid"),
    ("random", "rand"),
];
452
/// MySQL-compatible aggregate function aliases: (target_name, alias)
///
/// Registered by [`register_function_aliases`] alongside the scalar aliases.
const AGGREGATE_FUNCTION_ALIASES: &[(&str, &str)] =
    &[("stddev_pop", "std"), ("var_pop", "variance")];
456
457/// Register function aliases.
458///
459/// This function adds aliases like `ucase` -> `upper`, `lcase` -> `lower`, etc.
460/// to make GreptimeDB more compatible with MySQL syntax.
461fn register_function_aliases(ctx: &SessionContext) {
462    let state = ctx.state();
463
464    for (target, alias) in SCALAR_FUNCTION_ALIASES {
465        if let Some(func) = state.scalar_functions().get(*target) {
466            let aliased = func.as_ref().clone().with_aliases([*alias]);
467            ctx.register_udf(aliased);
468        }
469    }
470
471    for (target, alias) in AGGREGATE_FUNCTION_ALIASES {
472        if let Some(func) = state.aggregate_functions().get(*target) {
473            let aliased = func.as_ref().clone().with_aliases([*alias]);
474            ctx.register_udaf(aliased);
475        }
476    }
477}
478
479impl DfQueryPlanner {
480    fn new(
481        catalog_manager: CatalogManagerRef,
482        partition_rule_manager: Option<PartitionRuleManagerRef>,
483        region_query_handler: Option<RegionQueryHandlerRef>,
484    ) -> Self {
485        let mut planners: Vec<Arc<dyn ExtensionPlanner + Send + Sync>> =
486            vec![Arc::new(PromExtensionPlanner), Arc::new(RangeSelectPlanner)];
487        if let (Some(region_query_handler), Some(partition_rule_manager)) =
488            (region_query_handler, partition_rule_manager)
489        {
490            planners.push(Arc::new(DistExtensionPlanner::new(
491                catalog_manager,
492                partition_rule_manager,
493                region_query_handler,
494            )));
495            planners.push(Arc::new(MergeSortExtensionPlanner {}));
496        }
497        Self {
498            physical_planner: DefaultPhysicalPlanner::with_extension_planners(planners),
499        }
500    }
501}
502
/// A wrapper around TrackConsumersPool that records metrics.
///
/// This wrapper intercepts all memory pool operations and updates
/// Prometheus metrics for monitoring query memory usage and rejections.
#[derive(Debug)]
struct MetricsMemoryPool {
    /// Greedy pool with per-consumer tracking, so OOM errors can report the
    /// top memory consumers.
    inner: Arc<TrackConsumersPool<GreedyMemoryPool>>,
}
511
512impl MetricsMemoryPool {
513    // Number of top memory consumers to report in OOM error messages
514    const TOP_CONSUMERS_TO_REPORT: usize = 5;
515
516    fn new(limit: usize) -> Self {
517        Self {
518            inner: Arc::new(TrackConsumersPool::new(
519                GreedyMemoryPool::new(limit),
520                NonZeroUsize::new(Self::TOP_CONSUMERS_TO_REPORT).unwrap(),
521            )),
522        }
523    }
524
525    #[inline]
526    fn update_metrics(&self) {
527        QUERY_MEMORY_POOL_USAGE_BYTES.set(self.inner.reserved() as i64);
528    }
529}
530
531impl MemoryPool for MetricsMemoryPool {
532    fn register(&self, consumer: &MemoryConsumer) {
533        self.inner.register(consumer);
534    }
535
536    fn unregister(&self, consumer: &MemoryConsumer) {
537        self.inner.unregister(consumer);
538    }
539
540    fn grow(&self, reservation: &MemoryReservation, additional: usize) {
541        self.inner.grow(reservation, additional);
542        self.update_metrics();
543    }
544
545    fn shrink(&self, reservation: &MemoryReservation, shrink: usize) {
546        self.inner.shrink(reservation, shrink);
547        self.update_metrics();
548    }
549
550    fn try_grow(
551        &self,
552        reservation: &MemoryReservation,
553        additional: usize,
554    ) -> datafusion_common::Result<()> {
555        let result = self.inner.try_grow(reservation, additional);
556        if result.is_err() {
557            QUERY_MEMORY_POOL_REJECTED_TOTAL.inc();
558        }
559        self.update_metrics();
560        result
561    }
562
563    fn reserved(&self) -> usize {
564        self.inner.reserved()
565    }
566
567    fn memory_limit(&self) -> MemoryLimit {
568        self.inner.memory_limit()
569    }
570}