sqlness_runner/
protocol_interceptor.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use sqlness::interceptor::{Interceptor, InterceptorFactory, InterceptorRef};
16use sqlness::SqlnessError;
17
/// Prefix reported in the `InvalidContext` errors built by this module's
/// factory (presumably the directive name it is registered under — confirm
/// against the runner's interceptor registry).
pub const PREFIX: &str = "PROTOCOL";

/// Key under which the chosen protocol name is inserted into the query
/// context before each query executes.
pub const PROTOCOL_KEY: &str = "protocol";

/// Accepted protocol name for PostgreSQL (compared after lowercasing).
pub const POSTGRES: &str = "postgres";
/// Accepted protocol name for MySQL (compared after lowercasing).
pub const MYSQL: &str = "mysql";
22
/// Interceptor that records which wire protocol a test case should use.
///
/// Instances are created by the protocol interceptor factory, which only
/// accepts the `postgres` and `mysql` names.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtocolInterceptor {
    /// Protocol name copied into the query context before execution.
    protocol: String,
}
26
27impl Interceptor for ProtocolInterceptor {
28    fn before_execute(&self, _: &mut Vec<String>, context: &mut sqlness::QueryContext) {
29        context
30            .context
31            .insert(PROTOCOL_KEY.to_string(), self.protocol.clone());
32    }
33}
34
/// Stateless factory that builds a `ProtocolInterceptor` from a directive's
/// context string.
#[derive(Debug, Clone, Copy, Default)]
pub struct ProtocolInterceptorFactory;
36
37impl InterceptorFactory for ProtocolInterceptorFactory {
38    fn try_new(&self, ctx: &str) -> Result<InterceptorRef, SqlnessError> {
39        let protocol = ctx.to_lowercase();
40        match protocol.as_str() {
41            POSTGRES | MYSQL => Ok(Box::new(ProtocolInterceptor { protocol })),
42            _ => Err(SqlnessError::InvalidContext {
43                prefix: PREFIX.to_string(),
44                msg: format!("Unsupported protocol: {}", ctx),
45            }),
46        }
47    }
48}