frontend/
instance.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod builder;
16mod grpc;
17mod influxdb;
18mod jaeger;
19mod log_handler;
20mod logs;
21mod opentsdb;
22mod otlp;
23pub mod prom_store;
24mod promql;
25mod region_query;
26pub mod standalone;
27
28use std::pin::Pin;
29use std::sync::Arc;
30use std::time::{Duration, SystemTime};
31
32use async_stream::stream;
33use async_trait::async_trait;
34use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
35use catalog::process_manager::ProcessManagerRef;
36use catalog::CatalogManagerRef;
37use client::OutputData;
38use common_base::cancellation::CancellableFuture;
39use common_base::Plugins;
40use common_config::KvBackendConfig;
41use common_error::ext::{BoxedError, ErrorExt};
42use common_meta::cache_invalidator::CacheInvalidatorRef;
43use common_meta::ddl::ProcedureExecutorRef;
44use common_meta::key::runtime_switch::RuntimeSwitchManager;
45use common_meta::key::TableMetadataManagerRef;
46use common_meta::kv_backend::KvBackendRef;
47use common_meta::node_manager::NodeManagerRef;
48use common_meta::state_store::KvStateStore;
49use common_procedure::local::{LocalManager, ManagerConfig};
50use common_procedure::options::ProcedureConfig;
51use common_procedure::ProcedureManagerRef;
52use common_query::Output;
53use common_recordbatch::error::StreamTimeoutSnafu;
54use common_recordbatch::RecordBatchStreamWrapper;
55use common_telemetry::{debug, error, info, tracing};
56use datafusion_expr::LogicalPlan;
57use futures::{Stream, StreamExt};
58use log_store::raft_engine::RaftEngineBackend;
59use operator::delete::DeleterRef;
60use operator::insert::InserterRef;
61use operator::statement::{StatementExecutor, StatementExecutorRef};
62use partition::manager::PartitionRuleManagerRef;
63use pipeline::pipeline_operator::PipelineOperator;
64use prometheus::HistogramTimer;
65use promql_parser::label::Matcher;
66use query::metrics::OnDone;
67use query::parser::{PromQuery, QueryLanguageParser, QueryStatement};
68use query::query_engine::options::{validate_catalog_and_schema, QueryOptions};
69use query::query_engine::DescribeResult;
70use query::QueryEngineRef;
71use servers::error as server_error;
72use servers::error::{AuthSnafu, ExecuteQuerySnafu, ParsePromQLSnafu};
73use servers::interceptor::{
74    PromQueryInterceptor, PromQueryInterceptorRef, SqlQueryInterceptor, SqlQueryInterceptorRef,
75};
76use servers::prometheus_handler::PrometheusHandler;
77use servers::query_handler::sql::SqlQueryHandler;
78use session::context::{Channel, QueryContextRef};
79use session::table_name::table_idents_to_full_name;
80use snafu::prelude::*;
81use sql::dialect::Dialect;
82use sql::parser::{ParseOptions, ParserContext};
83use sql::statements::copy::{CopyDatabase, CopyTable};
84use sql::statements::statement::Statement;
85use sql::statements::tql::Tql;
86use sqlparser::ast::ObjectName;
87pub use standalone::StandaloneDatanodeManager;
88
89use crate::error::{
90    self, Error, ExecLogicalPlanSnafu, ExecutePromqlSnafu, ExternalSnafu, InvalidSqlSnafu,
91    ParseSqlSnafu, PermissionSnafu, PlanStatementSnafu, Result, SqlExecInterceptedSnafu,
92    StatementTimeoutSnafu, TableOperationSnafu,
93};
94use crate::limiter::LimiterRef;
95use crate::slow_query_recorder::SlowQueryRecorder;
96use crate::stream_wrapper::CancellableStreamWrapper;
97
98/// The frontend instance contains necessary components, and implements many
99/// traits, like [`servers::query_handler::grpc::GrpcQueryHandler`],
100/// [`servers::query_handler::sql::SqlQueryHandler`], etc.
101#[derive(Clone)]
102pub struct Instance {
103    catalog_manager: CatalogManagerRef,
104    pipeline_operator: Arc<PipelineOperator>,
105    statement_executor: Arc<StatementExecutor>,
106    query_engine: QueryEngineRef,
107    plugins: Plugins,
108    inserter: InserterRef,
109    deleter: DeleterRef,
110    table_metadata_manager: TableMetadataManagerRef,
111    slow_query_recorder: Option<SlowQueryRecorder>,
112    limiter: Option<LimiterRef>,
113    process_manager: ProcessManagerRef,
114}
115
116impl Instance {
117    pub async fn try_build_standalone_components(
118        dir: String,
119        kv_backend_config: KvBackendConfig,
120        procedure_config: ProcedureConfig,
121    ) -> Result<(KvBackendRef, ProcedureManagerRef)> {
122        info!(
123            "Creating metadata kvbackend with config: {:?}",
124            kv_backend_config
125        );
126        let kv_backend = RaftEngineBackend::try_open_with_cfg(dir, &kv_backend_config)
127            .map_err(BoxedError::new)
128            .context(error::OpenRaftEngineBackendSnafu)?;
129
130        let kv_backend = Arc::new(kv_backend);
131        let kv_state_store = Arc::new(KvStateStore::new(kv_backend.clone()));
132
133        let manager_config = ManagerConfig {
134            max_retry_times: procedure_config.max_retry_times,
135            retry_delay: procedure_config.retry_delay,
136            max_running_procedures: procedure_config.max_running_procedures,
137            ..Default::default()
138        };
139        let runtime_switch_manager = Arc::new(RuntimeSwitchManager::new(kv_backend.clone()));
140        let procedure_manager = Arc::new(LocalManager::new(
141            manager_config,
142            kv_state_store.clone(),
143            kv_state_store,
144            Some(runtime_switch_manager),
145        ));
146
147        Ok((kv_backend, procedure_manager))
148    }
149
150    pub fn catalog_manager(&self) -> &CatalogManagerRef {
151        &self.catalog_manager
152    }
153
154    pub fn query_engine(&self) -> &QueryEngineRef {
155        &self.query_engine
156    }
157
158    pub fn plugins(&self) -> &Plugins {
159        &self.plugins
160    }
161
162    pub fn statement_executor(&self) -> &StatementExecutorRef {
163        &self.statement_executor
164    }
165
166    pub fn table_metadata_manager(&self) -> &TableMetadataManagerRef {
167        &self.table_metadata_manager
168    }
169
170    pub fn inserter(&self) -> &InserterRef {
171        &self.inserter
172    }
173
174    pub fn process_manager(&self) -> &ProcessManagerRef {
175        &self.process_manager
176    }
177
178    pub fn node_manager(&self) -> &NodeManagerRef {
179        self.inserter.node_manager()
180    }
181
182    pub fn partition_manager(&self) -> &PartitionRuleManagerRef {
183        self.inserter.partition_manager()
184    }
185
186    pub fn cache_invalidator(&self) -> &CacheInvalidatorRef {
187        self.statement_executor.cache_invalidator()
188    }
189
190    pub fn procedure_executor(&self) -> &ProcedureExecutorRef {
191        self.statement_executor.procedure_executor()
192    }
193}
194
195fn parse_stmt(sql: &str, dialect: &(dyn Dialect + Send + Sync)) -> Result<Vec<Statement>> {
196    ParserContext::create_with_dialect(sql, dialect, ParseOptions::default()).context(ParseSqlSnafu)
197}
198
199impl Instance {
200    async fn query_statement(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result<Output> {
201        check_permission(self.plugins.clone(), &stmt, &query_ctx)?;
202
203        let query_interceptor = self.plugins.get::<SqlQueryInterceptorRef<Error>>();
204        let query_interceptor = query_interceptor.as_ref();
205
206        let _slow_query_timer = if let Some(recorder) = &self.slow_query_recorder {
207            recorder.start(QueryStatement::Sql(stmt.clone()), query_ctx.clone())
208        } else {
209            None
210        };
211
212        let ticket = self.process_manager.register_query(
213            query_ctx.current_catalog().to_string(),
214            vec![query_ctx.current_schema()],
215            stmt.to_string(),
216            query_ctx.conn_info().to_string(),
217            Some(query_ctx.process_id()),
218        );
219
220        let query_fut = self.exec_statement_with_timeout(stmt, query_ctx, query_interceptor);
221
222        CancellableFuture::new(query_fut, ticket.cancellation_handle.clone())
223            .await
224            .map_err(|_| error::CancelledSnafu.build())?
225            .map(|output| {
226                let Output { meta, data } = output;
227
228                let data = match data {
229                    OutputData::Stream(stream) => {
230                        OutputData::Stream(Box::pin(CancellableStreamWrapper::new(stream, ticket)))
231                    }
232                    other => other,
233                };
234                Output { data, meta }
235            })
236    }
237
238    async fn exec_statement_with_timeout(
239        &self,
240        stmt: Statement,
241        query_ctx: QueryContextRef,
242        query_interceptor: Option<&SqlQueryInterceptorRef<Error>>,
243    ) -> Result<Output> {
244        let timeout = derive_timeout(&stmt, &query_ctx);
245        match timeout {
246            Some(timeout) => {
247                let start = tokio::time::Instant::now();
248                let output = tokio::time::timeout(
249                    timeout,
250                    self.exec_statement(stmt, query_ctx, query_interceptor),
251                )
252                .await
253                .map_err(|_| StatementTimeoutSnafu.build())??;
254                // compute remaining timeout
255                let remaining_timeout = timeout.checked_sub(start.elapsed()).unwrap_or_default();
256                attach_timeout(output, remaining_timeout)
257            }
258            None => {
259                self.exec_statement(stmt, query_ctx, query_interceptor)
260                    .await
261            }
262        }
263    }
264
265    async fn exec_statement(
266        &self,
267        stmt: Statement,
268        query_ctx: QueryContextRef,
269        query_interceptor: Option<&SqlQueryInterceptorRef<Error>>,
270    ) -> Result<Output> {
271        match stmt {
272            Statement::Query(_) | Statement::Explain(_) | Statement::Delete(_) => {
273                // TODO: remove this when format is supported in datafusion
274                if let Statement::Explain(explain) = &stmt {
275                    if let Some(format) = explain.format() {
276                        query_ctx.set_explain_format(format.to_string());
277                    }
278                }
279
280                self.plan_and_exec_sql(stmt, &query_ctx, query_interceptor)
281                    .await
282            }
283            Statement::Tql(tql) => {
284                self.plan_and_exec_tql(&query_ctx, query_interceptor, tql)
285                    .await
286            }
287            _ => {
288                query_interceptor.pre_execute(&stmt, None, query_ctx.clone())?;
289                self.statement_executor
290                    .execute_sql(stmt, query_ctx)
291                    .await
292                    .context(TableOperationSnafu)
293            }
294        }
295    }
296
297    async fn plan_and_exec_sql(
298        &self,
299        stmt: Statement,
300        query_ctx: &QueryContextRef,
301        query_interceptor: Option<&SqlQueryInterceptorRef<Error>>,
302    ) -> Result<Output> {
303        let stmt = QueryStatement::Sql(stmt);
304        let plan = self
305            .statement_executor
306            .plan(&stmt, query_ctx.clone())
307            .await?;
308        let QueryStatement::Sql(stmt) = stmt else {
309            unreachable!()
310        };
311        query_interceptor.pre_execute(&stmt, Some(&plan), query_ctx.clone())?;
312        self.statement_executor
313            .exec_plan(plan, query_ctx.clone())
314            .await
315            .context(TableOperationSnafu)
316    }
317
318    async fn plan_and_exec_tql(
319        &self,
320        query_ctx: &QueryContextRef,
321        query_interceptor: Option<&SqlQueryInterceptorRef<Error>>,
322        tql: Tql,
323    ) -> Result<Output> {
324        let plan = self
325            .statement_executor
326            .plan_tql(tql.clone(), query_ctx)
327            .await?;
328        query_interceptor.pre_execute(&Statement::Tql(tql), Some(&plan), query_ctx.clone())?;
329        self.statement_executor
330            .exec_plan(plan, query_ctx.clone())
331            .await
332            .context(TableOperationSnafu)
333    }
334}
335
336/// If the relevant variables are set, the timeout is enforced for all PostgreSQL statements.
337/// For MySQL, it applies only to read-only statements.
338fn derive_timeout(stmt: &Statement, query_ctx: &QueryContextRef) -> Option<Duration> {
339    let query_timeout = query_ctx.query_timeout()?;
340    if query_timeout.is_zero() {
341        return None;
342    }
343    match query_ctx.channel() {
344        Channel::Mysql if stmt.is_readonly() => Some(query_timeout),
345        Channel::Postgres => Some(query_timeout),
346        _ => None,
347    }
348}
349
350fn attach_timeout(output: Output, mut timeout: Duration) -> Result<Output> {
351    if timeout.is_zero() {
352        return StatementTimeoutSnafu.fail();
353    }
354
355    let output = match output.data {
356        OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => output,
357        OutputData::Stream(mut stream) => {
358            let schema = stream.schema();
359            let s = Box::pin(stream! {
360                let mut start = tokio::time::Instant::now();
361                while let Some(item) = tokio::time::timeout(timeout, stream.next()).await.map_err(|_| StreamTimeoutSnafu.build())? {
362                    yield item;
363
364                    let now = tokio::time::Instant::now();
365                    timeout = timeout.checked_sub(now - start).unwrap_or(Duration::ZERO);
366                    start = now;
367                    // tokio::time::timeout may not return an error immediately when timeout is 0.
368                    if timeout.is_zero() {
369                        StreamTimeoutSnafu.fail()?;
370                    }
371                }
372            }) as Pin<Box<dyn Stream<Item = _> + Send>>;
373            let stream = RecordBatchStreamWrapper {
374                schema,
375                stream: s,
376                output_ordering: None,
377                metrics: Default::default(),
378            };
379            Output::new(OutputData::Stream(Box::pin(stream)), output.meta)
380        }
381    };
382
383    Ok(output)
384}
385
386#[async_trait]
387impl SqlQueryHandler for Instance {
388    type Error = Error;
389
390    #[tracing::instrument(skip_all)]
391    async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
392        let query_interceptor_opt = self.plugins.get::<SqlQueryInterceptorRef<Error>>();
393        let query_interceptor = query_interceptor_opt.as_ref();
394        let query = match query_interceptor.pre_parsing(query, query_ctx.clone()) {
395            Ok(q) => q,
396            Err(e) => return vec![Err(e)],
397        };
398
399        let checker_ref = self.plugins.get::<PermissionCheckerRef>();
400        let checker = checker_ref.as_ref();
401
402        match parse_stmt(query.as_ref(), query_ctx.sql_dialect())
403            .and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone()))
404        {
405            Ok(stmts) => {
406                if stmts.is_empty() {
407                    return vec![InvalidSqlSnafu {
408                        err_msg: "empty statements",
409                    }
410                    .fail()];
411                }
412
413                let mut results = Vec::with_capacity(stmts.len());
414                for stmt in stmts {
415                    if let Err(e) = checker
416                        .check_permission(
417                            query_ctx.current_user(),
418                            PermissionReq::SqlStatement(&stmt),
419                        )
420                        .context(PermissionSnafu)
421                    {
422                        results.push(Err(e));
423                        break;
424                    }
425
426                    match self.query_statement(stmt.clone(), query_ctx.clone()).await {
427                        Ok(output) => {
428                            let output_result =
429                                query_interceptor.post_execute(output, query_ctx.clone());
430                            results.push(output_result);
431                        }
432                        Err(e) => {
433                            if e.status_code().should_log_error() {
434                                error!(e; "Failed to execute query: {stmt}");
435                            } else {
436                                debug!("Failed to execute query: {stmt}, {e}");
437                            }
438                            results.push(Err(e));
439                            break;
440                        }
441                    }
442                }
443                results
444            }
445            Err(e) => {
446                vec![Err(e)]
447            }
448        }
449    }
450
451    async fn do_exec_plan(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
452        // plan should be prepared before exec
453        // we'll do check there
454        self.query_engine
455            .execute(plan.clone(), query_ctx)
456            .await
457            .context(ExecLogicalPlanSnafu)
458    }
459
460    #[tracing::instrument(skip_all)]
461    async fn do_promql_query(
462        &self,
463        query: &PromQuery,
464        query_ctx: QueryContextRef,
465    ) -> Vec<Result<Output>> {
466        // check will be done in prometheus handler's do_query
467        let result = PrometheusHandler::do_query(self, query, query_ctx)
468            .await
469            .with_context(|_| ExecutePromqlSnafu {
470                query: format!("{query:?}"),
471            });
472        vec![result]
473    }
474
475    async fn do_describe(
476        &self,
477        stmt: Statement,
478        query_ctx: QueryContextRef,
479    ) -> Result<Option<DescribeResult>> {
480        if matches!(
481            stmt,
482            Statement::Insert(_) | Statement::Query(_) | Statement::Delete(_)
483        ) {
484            self.plugins
485                .get::<PermissionCheckerRef>()
486                .as_ref()
487                .check_permission(query_ctx.current_user(), PermissionReq::SqlStatement(&stmt))
488                .context(PermissionSnafu)?;
489
490            let plan = self
491                .query_engine
492                .planner()
493                .plan(&QueryStatement::Sql(stmt), query_ctx.clone())
494                .await
495                .context(PlanStatementSnafu)?;
496            self.query_engine
497                .describe(plan, query_ctx)
498                .await
499                .map(Some)
500                .context(error::DescribeStatementSnafu)
501        } else {
502            Ok(None)
503        }
504    }
505
506    async fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result<bool> {
507        self.catalog_manager
508            .schema_exists(catalog, schema, None)
509            .await
510            .context(error::CatalogSnafu)
511    }
512}
513
514/// Attaches a timer to the output and observes it once the output is exhausted.
515pub fn attach_timer(output: Output, timer: HistogramTimer) -> Output {
516    match output.data {
517        OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => output,
518        OutputData::Stream(stream) => {
519            let stream = OnDone::new(stream, move || {
520                timer.observe_duration();
521            });
522            Output::new(OutputData::Stream(Box::pin(stream)), output.meta)
523        }
524    }
525}
526
527#[async_trait]
528impl PrometheusHandler for Instance {
529    #[tracing::instrument(skip_all)]
530    async fn do_query(
531        &self,
532        query: &PromQuery,
533        query_ctx: QueryContextRef,
534    ) -> server_error::Result<Output> {
535        let interceptor = self
536            .plugins
537            .get::<PromQueryInterceptorRef<server_error::Error>>();
538
539        self.plugins
540            .get::<PermissionCheckerRef>()
541            .as_ref()
542            .check_permission(query_ctx.current_user(), PermissionReq::PromQuery)
543            .context(AuthSnafu)?;
544
545        let stmt = QueryLanguageParser::parse_promql(query, &query_ctx).with_context(|_| {
546            ParsePromQLSnafu {
547                query: query.clone(),
548            }
549        })?;
550
551        let _slow_query_timer = if let Some(recorder) = &self.slow_query_recorder {
552            recorder.start(stmt.clone(), query_ctx.clone())
553        } else {
554            None
555        };
556
557        let plan = self
558            .statement_executor
559            .plan(&stmt, query_ctx.clone())
560            .await
561            .map_err(BoxedError::new)
562            .context(ExecuteQuerySnafu)?;
563
564        interceptor.pre_execute(query, Some(&plan), query_ctx.clone())?;
565
566        let output = self
567            .statement_executor
568            .exec_plan(plan, query_ctx.clone())
569            .await
570            .map_err(BoxedError::new)
571            .context(ExecuteQuerySnafu)?;
572
573        Ok(interceptor.post_execute(output, query_ctx)?)
574    }
575
576    async fn query_metric_names(
577        &self,
578        matchers: Vec<Matcher>,
579        ctx: &QueryContextRef,
580    ) -> server_error::Result<Vec<String>> {
581        self.handle_query_metric_names(matchers, ctx)
582            .await
583            .map_err(BoxedError::new)
584            .context(ExecuteQuerySnafu)
585    }
586
587    async fn query_label_values(
588        &self,
589        metric: String,
590        label_name: String,
591        matchers: Vec<Matcher>,
592        start: SystemTime,
593        end: SystemTime,
594        ctx: &QueryContextRef,
595    ) -> server_error::Result<Vec<String>> {
596        self.handle_query_label_values(metric, label_name, matchers, start, end, ctx)
597            .await
598            .map_err(BoxedError::new)
599            .context(ExecuteQuerySnafu)
600    }
601
602    fn catalog_manager(&self) -> CatalogManagerRef {
603        self.catalog_manager.clone()
604    }
605}
606
/// Validates access to the database named by `stmt.database`, if one is present.
///
/// The database is checked against the current catalog of `query_ctx`. On failure the
/// enclosing function returns early with an error, so this macro can only be used inside
/// functions that return this module's `Result`.
macro_rules! validate_db_permission {
    ($stmt: expr, $query_ctx: expr) => {
        if let Some(database) = $stmt.database.as_ref() {
            let checked =
                validate_catalog_and_schema($query_ctx.current_catalog(), database, $query_ctx);
            checked
                .map_err(BoxedError::new)
                .context(SqlExecInterceptedSnafu)?;
        }
    };
}
617
618pub fn check_permission(
619    plugins: Plugins,
620    stmt: &Statement,
621    query_ctx: &QueryContextRef,
622) -> Result<()> {
623    let need_validate = plugins
624        .get::<QueryOptions>()
625        .map(|opts| opts.disallow_cross_catalog_query)
626        .unwrap_or_default();
627
628    if !need_validate {
629        return Ok(());
630    }
631
632    match stmt {
633        // Will be checked in execution.
634        // TODO(dennis): add a hook for admin commands.
635        Statement::Admin(_) => {}
636        // These are executed by query engine, and will be checked there.
637        Statement::Query(_)
638        | Statement::Explain(_)
639        | Statement::Tql(_)
640        | Statement::Delete(_)
641        | Statement::DeclareCursor(_)
642        | Statement::Copy(sql::statements::copy::Copy::CopyQueryTo(_)) => {}
643        // database ops won't be checked
644        Statement::CreateDatabase(_)
645        | Statement::ShowDatabases(_)
646        | Statement::DropDatabase(_)
647        | Statement::AlterDatabase(_)
648        | Statement::DropFlow(_)
649        | Statement::Use(_) => {}
650        #[cfg(feature = "enterprise")]
651        Statement::DropTrigger(_) => {}
652        Statement::ShowCreateDatabase(stmt) => {
653            validate_database(&stmt.database_name, query_ctx)?;
654        }
655        Statement::ShowCreateTable(stmt) => {
656            validate_param(&stmt.table_name, query_ctx)?;
657        }
658        Statement::ShowCreateFlow(stmt) => {
659            validate_param(&stmt.flow_name, query_ctx)?;
660        }
661        Statement::ShowCreateView(stmt) => {
662            validate_param(&stmt.view_name, query_ctx)?;
663        }
664        Statement::CreateExternalTable(stmt) => {
665            validate_param(&stmt.name, query_ctx)?;
666        }
667        Statement::CreateFlow(stmt) => {
668            // TODO: should also validate source table name here?
669            validate_param(&stmt.sink_table_name, query_ctx)?;
670        }
671        #[cfg(feature = "enterprise")]
672        Statement::CreateTrigger(stmt) => {
673            validate_param(&stmt.trigger_name, query_ctx)?;
674        }
675        Statement::CreateView(stmt) => {
676            validate_param(&stmt.name, query_ctx)?;
677        }
678        Statement::AlterTable(stmt) => {
679            validate_param(stmt.table_name(), query_ctx)?;
680        }
681        // set/show variable now only alter/show variable in session
682        Statement::SetVariables(_) | Statement::ShowVariables(_) => {}
683        // show charset and show collation won't be checked
684        Statement::ShowCharset(_) | Statement::ShowCollation(_) => {}
685
686        Statement::Insert(insert) => {
687            let name = insert.table_name().context(ParseSqlSnafu)?;
688            validate_param(name, query_ctx)?;
689        }
690        Statement::CreateTable(stmt) => {
691            validate_param(&stmt.name, query_ctx)?;
692        }
693        Statement::CreateTableLike(stmt) => {
694            validate_param(&stmt.table_name, query_ctx)?;
695            validate_param(&stmt.source_name, query_ctx)?;
696        }
697        Statement::DropTable(drop_stmt) => {
698            for table_name in drop_stmt.table_names() {
699                validate_param(table_name, query_ctx)?;
700            }
701        }
702        Statement::DropView(stmt) => {
703            validate_param(&stmt.view_name, query_ctx)?;
704        }
705        Statement::ShowTables(stmt) => {
706            validate_db_permission!(stmt, query_ctx);
707        }
708        Statement::ShowTableStatus(stmt) => {
709            validate_db_permission!(stmt, query_ctx);
710        }
711        Statement::ShowColumns(stmt) => {
712            validate_db_permission!(stmt, query_ctx);
713        }
714        Statement::ShowIndex(stmt) => {
715            validate_db_permission!(stmt, query_ctx);
716        }
717        Statement::ShowRegion(stmt) => {
718            validate_db_permission!(stmt, query_ctx);
719        }
720        Statement::ShowViews(stmt) => {
721            validate_db_permission!(stmt, query_ctx);
722        }
723        Statement::ShowFlows(stmt) => {
724            validate_db_permission!(stmt, query_ctx);
725        }
726        #[cfg(feature = "enterprise")]
727        Statement::ShowTriggers(_stmt) => {
728            // The trigger is organized based on the catalog dimension, so there
729            // is no need to check the permission of the database(schema).
730        }
731        Statement::ShowStatus(_stmt) => {}
732        Statement::ShowSearchPath(_stmt) => {}
733        Statement::DescribeTable(stmt) => {
734            validate_param(stmt.name(), query_ctx)?;
735        }
736        Statement::Copy(sql::statements::copy::Copy::CopyTable(stmt)) => match stmt {
737            CopyTable::To(copy_table_to) => validate_param(&copy_table_to.table_name, query_ctx)?,
738            CopyTable::From(copy_table_from) => {
739                validate_param(&copy_table_from.table_name, query_ctx)?
740            }
741        },
742        Statement::Copy(sql::statements::copy::Copy::CopyDatabase(copy_database)) => {
743            match copy_database {
744                CopyDatabase::To(stmt) => validate_database(&stmt.database_name, query_ctx)?,
745                CopyDatabase::From(stmt) => validate_database(&stmt.database_name, query_ctx)?,
746            }
747        }
748        Statement::TruncateTable(stmt) => {
749            validate_param(stmt.table_name(), query_ctx)?;
750        }
751        // cursor operations are always allowed once it's created
752        Statement::FetchCursor(_) | Statement::CloseCursor(_) => {}
753        // User can only kill process in their own catalog.
754        Statement::Kill(_) => {}
755        // SHOW PROCESSLIST
756        Statement::ShowProcesslist(_) => {}
757    }
758    Ok(())
759}
760
761fn validate_param(name: &ObjectName, query_ctx: &QueryContextRef) -> Result<()> {
762    let (catalog, schema, _) = table_idents_to_full_name(name, query_ctx)
763        .map_err(BoxedError::new)
764        .context(ExternalSnafu)?;
765
766    validate_catalog_and_schema(&catalog, &schema, query_ctx)
767        .map_err(BoxedError::new)
768        .context(SqlExecInterceptedSnafu)
769}
770
771fn validate_database(name: &ObjectName, query_ctx: &QueryContextRef) -> Result<()> {
772    let (catalog, schema) = match &name.0[..] {
773        [schema] => (
774            query_ctx.current_catalog().to_string(),
775            schema.value.clone(),
776        ),
777        [catalog, schema] => (catalog.value.clone(), schema.value.clone()),
778        _ => InvalidSqlSnafu {
779            err_msg: format!(
780                "expect database name to be <catalog>.<schema> or <schema>, actual: {name}",
781            ),
782        }
783        .fail()?,
784    };
785
786    validate_catalog_and_schema(&catalog, &schema, query_ctx)
787        .map_err(BoxedError::new)
788        .context(SqlExecInterceptedSnafu)
789}
790
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use common_base::Plugins;
    use query::query_engine::options::QueryOptions;
    use session::context::QueryContext;
    use sql::dialect::GreptimeDbDialect;
    use strfmt::Format;

    use super::*;

    /// Fills the `{catalog}` and `{schema}` placeholders of `template`.
    fn fill_template(template: &str, catalog: &str, schema: &str) -> String {
        let vars = HashMap::from([
            ("catalog".to_string(), catalog),
            ("schema".to_string(), schema),
        ]);
        template.format(&vars).unwrap()
    }

    /// Parses `sql` and asserts that checking its first statement succeeds or fails as
    /// requested by `expect_ok`.
    fn assert_permission(sql: &str, plugins: &Plugins, query_ctx: &QueryContextRef, expect_ok: bool) {
        let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
        let outcome = check_permission(plugins.clone(), &stmts[0], query_ctx);
        if expect_ok {
            outcome.unwrap();
        } else {
            assert!(outcome.is_err());
        }
    }

    /// Runs `template_sql` with accepted and rejected catalog/schema prefixes.
    fn assert_qualified_names(template_sql: &str, plugins: &Plugins, query_ctx: &QueryContextRef) {
        // test right
        for (catalog, schema) in [("", ""), ("", "public."), ("greptime.", "public.")] {
            let sql = fill_template(template_sql, catalog, schema);
            assert_permission(&sql, plugins, query_ctx, true);
        }

        // test wrong
        for (catalog, schema) in [
            ("wrongcatalog.", "public."),
            ("wrongcatalog.", "wrongschema."),
        ] {
            let sql = fill_template(template_sql, catalog, schema);
            assert_permission(&sql, plugins, query_ctx, false);
        }
    }

    #[test]
    fn test_exec_validation() {
        let query_ctx = QueryContext::arc();
        let plugins: Plugins = Plugins::new();
        plugins.insert(QueryOptions {
            disallow_cross_catalog_query: true,
        });

        // Statements that are not validated at this level.
        let sql = r#"
        SELECT * FROM demo;
        EXPLAIN SELECT * FROM demo;
        CREATE DATABASE test_database;
        SHOW DATABASES;
        "#;
        let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
        assert_eq!(stmts.len(), 4);
        for stmt in &stmts {
            check_permission(plugins.clone(), stmt, &query_ctx).unwrap();
        }

        // Statements on an unqualified table name.
        let sql = r#"
        SHOW CREATE TABLE demo;
        ALTER TABLE demo ADD COLUMN new_col INT;
        "#;
        let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
        assert_eq!(stmts.len(), 2);
        for stmt in &stmts {
            check_permission(plugins.clone(), stmt, &query_ctx).unwrap();
        }

        // test insert
        let insert_sql = "INSERT INTO {catalog}{schema}monitor(host) VALUES ('host1');";
        assert_qualified_names(insert_sql, &plugins, &query_ctx);

        // test create table
        let create_sql = r#"CREATE TABLE {catalog}{schema}demo(
                            host STRING,
                            ts TIMESTAMP,
                            TIME INDEX (ts),
                            PRIMARY KEY(host)
                        ) engine=mito;"#;
        assert_qualified_names(create_sql, &plugins, &query_ctx);

        // test drop table
        let drop_sql = "DROP TABLE {catalog}{schema}demo;";
        assert_qualified_names(drop_sql, &plugins, &query_ctx);

        // test show tables
        let stmts = parse_stmt("SHOW TABLES FROM public", &GreptimeDbDialect {}).unwrap();
        check_permission(plugins.clone(), &stmts[0], &query_ctx).unwrap();

        let stmts = parse_stmt("SHOW TABLES FROM private", &GreptimeDbDialect {}).unwrap();
        let outcome = check_permission(plugins.clone(), &stmts[0], &query_ctx);
        assert!(outcome.is_ok());

        // test describe table
        let describe_sql = "DESC TABLE {catalog}{schema}demo;";
        assert_qualified_names(describe_sql, &plugins, &query_ctx);
    }
}