query/datafusion/planner/
function_alias.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16
17use once_cell::sync::Lazy;
18
/// Scalar function aliases as `(alias, target)` pairs, in the order reported by
/// `scalar_alias_names`.
///
/// Alias keys must be lowercase: the resolver lowercases the incoming name before looking it up.
const SCALAR_ALIASES: &[(&str, &str)] = &[
    // Alternative spellings accepted for SQL compatibility.
    ("ucase", "upper"),
    ("lcase", "lower"),
    ("ceiling", "ceil"),
    ("mid", "substr"),
    // MySQL's RAND([seed]) takes an optional seed, but DataFusion's `random()` takes none.
    // Only the name is mapped, so `rand()` works while `rand(seed)` fails on arity.
    ("rand", "random"),
];
30
/// Aggregate function aliases as `(alias, target)` pairs, in the order reported by
/// `aggregate_alias_names`.
///
/// Only MySQL-compatible names that do not shadow an existing DataFusion aggregate are listed.
/// Keys must be lowercase, because the resolver lowercases the incoming name first.
///
/// `stddev` is deliberately absent: DataFusion's `stddev` is the *sample* standard deviation,
/// whereas MySQL's `STDDEV` is the *population* standard deviation.
const AGGREGATE_ALIASES: &[(&str, &str)] = &[
    ("std", "stddev_pop"),
    ("variance", "var_pop"),
];
39
40static SCALAR_FUNCTION_ALIAS: Lazy<HashMap<&'static str, &'static str>> =
41    Lazy::new(|| SCALAR_ALIASES.iter().copied().collect());
42
43static AGGREGATE_FUNCTION_ALIAS: Lazy<HashMap<&'static str, &'static str>> =
44    Lazy::new(|| AGGREGATE_ALIASES.iter().copied().collect());
45
46pub fn resolve_scalar(name: &str) -> Option<&'static str> {
47    let name = name.to_ascii_lowercase();
48    SCALAR_FUNCTION_ALIAS.get(name.as_str()).copied()
49}
50
51pub fn resolve_aggregate(name: &str) -> Option<&'static str> {
52    let name = name.to_ascii_lowercase();
53    AGGREGATE_FUNCTION_ALIAS.get(name.as_str()).copied()
54}
55
56pub fn scalar_alias_names() -> impl Iterator<Item = &'static str> {
57    SCALAR_ALIASES.iter().map(|(name, _)| *name)
58}
59
60pub fn aggregate_alias_names() -> impl Iterator<Item = &'static str> {
61    AGGREGATE_ALIASES.iter().map(|(name, _)| *name)
62}
63
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::{aggregate_alias_names, resolve_aggregate, resolve_scalar, scalar_alias_names};

    #[test]
    fn resolves_scalar_aliases_case_insensitive() {
        assert_eq!(resolve_scalar("ucase"), Some("upper"));
        assert_eq!(resolve_scalar("UCASE"), Some("upper"));
        assert_eq!(resolve_scalar("lcase"), Some("lower"));
        assert_eq!(resolve_scalar("ceiling"), Some("ceil"));
        assert_eq!(resolve_scalar("MID"), Some("substr"));
        assert_eq!(resolve_scalar("RAND"), Some("random"));
        assert_eq!(resolve_scalar("not_a_real_alias"), None);
    }

    #[test]
    fn resolves_aggregate_aliases_case_insensitive() {
        assert_eq!(resolve_aggregate("std"), Some("stddev_pop"));
        assert_eq!(resolve_aggregate("STD"), Some("stddev_pop"));
        assert_eq!(resolve_aggregate("variance"), Some("var_pop"));
        assert_eq!(resolve_aggregate("Variance"), Some("var_pop"));
        // `stddev` is intentionally not aliased (sample vs. population semantics differ).
        assert_eq!(resolve_aggregate("STDDEV"), None);
        assert_eq!(resolve_aggregate("not_a_real_alias"), None);
    }

    /// The resolvers lowercase their input before the lookup, so an alias key containing an
    /// uppercase letter could never be matched. Every advertised alias must be lowercase and
    /// must resolve identically regardless of the caller's casing.
    #[test]
    fn advertised_aliases_are_lowercase_and_resolvable() {
        for name in scalar_alias_names() {
            assert_eq!(
                name,
                name.to_ascii_lowercase(),
                "scalar alias `{}` must be lowercase",
                name
            );
            assert!(
                resolve_scalar(name).is_some(),
                "scalar alias `{}` does not resolve",
                name
            );
            assert_eq!(resolve_scalar(&name.to_ascii_uppercase()), resolve_scalar(name));
        }
        for name in aggregate_alias_names() {
            assert_eq!(
                name,
                name.to_ascii_lowercase(),
                "aggregate alias `{}` must be lowercase",
                name
            );
            assert!(
                resolve_aggregate(name).is_some(),
                "aggregate alias `{}` does not resolve",
                name
            );
            assert_eq!(
                resolve_aggregate(&name.to_ascii_uppercase()),
                resolve_aggregate(name)
            );
        }
    }

    /// Duplicate alias keys would be silently collapsed when the lookup maps are built,
    /// leaving the name iterators and the resolvers out of sync.
    #[test]
    fn alias_names_are_unique() {
        let mut seen = HashSet::new();
        assert!(
            scalar_alias_names().all(|name| seen.insert(name)),
            "duplicate scalar alias"
        );

        let mut seen = HashSet::new();
        assert!(
            aggregate_alias_names().all(|name| seen.insert(name)),
            "duplicate aggregate alias"
        );
    }
}