servers/grpc/
greptime_handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Handler for Greptime Database service. It's implemented by frontend.
16
17use std::str::FromStr;
18use std::time::Instant;
19
20use api::helper::request_type;
21use api::v1::auth_header::AuthScheme;
22use api::v1::{AuthHeader, Basic, GreptimeRequest, RequestHeader};
23use auth::{Identity, Password, UserInfoRef, UserProviderRef};
24use base64::prelude::BASE64_STANDARD;
25use base64::Engine;
26use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
27use common_catalog::parse_catalog_and_schema_from_db_string;
28use common_error::ext::ErrorExt;
29use common_error::status_code::StatusCode;
30use common_grpc::flight::do_put::DoPutResponse;
31use common_grpc::flight::FlightDecoder;
32use common_query::Output;
33use common_runtime::runtime::RuntimeTrait;
34use common_runtime::Runtime;
35use common_session::ReadPreference;
36use common_telemetry::tracing_context::{FutureExt, TracingContext};
37use common_telemetry::{debug, error, tracing, warn};
38use common_time::timezone::parse_timezone;
39use futures_util::StreamExt;
40use session::context::{QueryContext, QueryContextBuilder, QueryContextRef};
41use snafu::{OptionExt, ResultExt};
42use table::metadata::TableId;
43use tokio::sync::mpsc;
44
45use crate::error::Error::UnsupportedAuthScheme;
46use crate::error::{
47    AuthSnafu, InvalidAuthHeaderInvalidUtf8ValueSnafu, InvalidBase64ValueSnafu, InvalidQuerySnafu,
48    JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result, UnknownHintSnafu,
49};
50use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream};
51use crate::grpc::TonicResult;
52use crate::hint_headers::READ_PREFERENCE_HINT;
53use crate::metrics;
54use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
55use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
56
57#[derive(Clone)]
58pub struct GreptimeRequestHandler {
59    handler: ServerGrpcQueryHandlerRef,
60    user_provider: Option<UserProviderRef>,
61    runtime: Option<Runtime>,
62}
63
64impl GreptimeRequestHandler {
65    pub fn new(
66        handler: ServerGrpcQueryHandlerRef,
67        user_provider: Option<UserProviderRef>,
68        runtime: Option<Runtime>,
69    ) -> Self {
70        Self {
71            handler,
72            user_provider,
73            runtime,
74        }
75    }
76
77    #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
78    pub(crate) async fn handle_request(
79        &self,
80        request: GreptimeRequest,
81        hints: Vec<(String, String)>,
82    ) -> Result<Output> {
83        let query = request.request.context(InvalidQuerySnafu {
84            reason: "Expecting non-empty GreptimeRequest.",
85        })?;
86
87        let header = request.header.as_ref();
88        let query_ctx = create_query_context(header, hints)?;
89        let user_info = auth(self.user_provider.clone(), header, &query_ctx).await?;
90        query_ctx.set_current_user(user_info);
91
92        let handler = self.handler.clone();
93        let request_type = request_type(&query).to_string();
94        let db = query_ctx.get_db_string();
95        let timer = RequestTimer::new(db.clone(), request_type);
96        let tracing_context = TracingContext::from_current_span();
97
98        let result_future = async move {
99            handler
100                .do_query(query, query_ctx)
101                .trace(tracing_context.attach(tracing::info_span!(
102                    "GreptimeRequestHandler::handle_request_runtime"
103                )))
104                .await
105                .map_err(|e| {
106                    if e.status_code().should_log_error() {
107                        let root_error = e.root_cause().unwrap_or(&e);
108                        error!(e; "Failed to handle request, error: {}", root_error.to_string());
109                    } else {
110                        // Currently, we still print a debug log.
111                        debug!("Failed to handle request, err: {:?}", e);
112                    }
113                    e
114                })
115        };
116
117        match &self.runtime {
118            Some(runtime) => {
119                // Executes requests in another runtime to
120                // 1. prevent the execution from being cancelled unexpected by Tonic runtime;
121                //   - Refer to our blog for the rational behind it:
122                //     https://www.greptime.com/blogs/2023-01-12-hidden-control-flow.html
123                //   - Obtaining a `JoinHandle` to get the panic message (if there's any).
124                //     From its docs, `JoinHandle` is cancel safe. The task keeps running even it's handle been dropped.
125                // 2. avoid the handler blocks the gRPC runtime incidentally.
126                runtime
127                    .spawn(result_future)
128                    .await
129                    .context(JoinTaskSnafu)
130                    .inspect_err(|e| {
131                        timer.record(e.status_code());
132                    })?
133            }
134            None => result_future.await,
135        }
136    }
137
138    pub(crate) async fn put_record_batches(
139        &self,
140        mut stream: PutRecordBatchRequestStream,
141        result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
142    ) {
143        let handler = self.handler.clone();
144        let runtime = self
145            .runtime
146            .clone()
147            .unwrap_or_else(common_runtime::global_runtime);
148        runtime.spawn(async move {
149            // Cached table id
150            let mut table_id: Option<TableId> = None;
151
152            let mut decoder = FlightDecoder::default();
153            while let Some(request) = stream.next().await {
154                let request = match request {
155                    Ok(request) => request,
156                    Err(e) => {
157                        let _ = result_sender.try_send(Err(e));
158                        break;
159                    }
160                };
161                let PutRecordBatchRequest {
162                    table_name,
163                    request_id,
164                    data,
165                } = request;
166
167                let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer();
168                let result = handler
169                    .put_record_batch(&table_name, &mut table_id, &mut decoder, data)
170                    .await;
171                timer.observe_duration();
172                let result = result
173                    .map(|x| DoPutResponse::new(request_id, x))
174                    .map_err(Into::into);
175                if result_sender.try_send(result).is_err() {
176                    warn!(r#""DoPut" client maybe unreachable, abort handling its message"#);
177                    break;
178                }
179            }
180        });
181    }
182
183    pub(crate) async fn validate_auth(
184        &self,
185        username_and_password: Option<&str>,
186        db: Option<&str>,
187    ) -> Result<bool> {
188        if self.user_provider.is_none() {
189            return Ok(true);
190        }
191
192        let username_and_password = username_and_password.context(NotFoundAuthHeaderSnafu)?;
193        let username_and_password = BASE64_STANDARD
194            .decode(username_and_password)
195            .context(InvalidBase64ValueSnafu)
196            .and_then(|x| String::from_utf8(x).context(InvalidAuthHeaderInvalidUtf8ValueSnafu))?;
197
198        let mut split = username_and_password.splitn(2, ':');
199        let (username, password) = match (split.next(), split.next()) {
200            (Some(username), Some(password)) => (username, password),
201            (Some(username), None) => (username, ""),
202            (None, None) => return Ok(false),
203            _ => unreachable!(), // because this iterator won't yield Some after None
204        };
205
206        let (catalog, schema) = if let Some(db) = db {
207            parse_catalog_and_schema_from_db_string(db)
208        } else {
209            (
210                DEFAULT_CATALOG_NAME.to_string(),
211                DEFAULT_SCHEMA_NAME.to_string(),
212            )
213        };
214        let header = RequestHeader {
215            authorization: Some(AuthHeader {
216                auth_scheme: Some(AuthScheme::Basic(Basic {
217                    username: username.to_string(),
218                    password: password.to_string(),
219                })),
220            }),
221            catalog,
222            schema,
223            ..Default::default()
224        };
225
226        Ok(auth(
227            self.user_provider.clone(),
228            Some(&header),
229            &QueryContext::arc(),
230        )
231        .await
232        .is_ok())
233    }
234}
235
236pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
237    request
238        .request
239        .as_ref()
240        .map(request_type)
241        .unwrap_or_default()
242}
243
244pub(crate) async fn auth(
245    user_provider: Option<UserProviderRef>,
246    header: Option<&RequestHeader>,
247    query_ctx: &QueryContextRef,
248) -> Result<UserInfoRef> {
249    let Some(user_provider) = user_provider else {
250        return Ok(auth::userinfo_by_name(None));
251    };
252
253    let auth_scheme = header
254        .and_then(|header| {
255            header
256                .authorization
257                .as_ref()
258                .and_then(|x| x.auth_scheme.clone())
259        })
260        .context(NotFoundAuthHeaderSnafu)?;
261
262    match auth_scheme {
263        AuthScheme::Basic(Basic { username, password }) => user_provider
264            .auth(
265                Identity::UserId(&username, None),
266                Password::PlainText(password.into()),
267                query_ctx.current_catalog(),
268                &query_ctx.current_schema(),
269            )
270            .await
271            .context(AuthSnafu),
272        AuthScheme::Token(_) => Err(UnsupportedAuthScheme {
273            name: "Token AuthScheme".to_string(),
274        }),
275    }
276    .inspect_err(|e| {
277        METRIC_AUTH_FAILURE
278            .with_label_values(&[e.status_code().as_ref()])
279            .inc();
280    })
281}
282
283pub(crate) fn create_query_context(
284    header: Option<&RequestHeader>,
285    mut extensions: Vec<(String, String)>,
286) -> Result<QueryContextRef> {
287    let (catalog, schema) = header
288        .map(|header| {
289            // We provide dbname field in newer versions of protos/sdks
290            // parse dbname from header in priority
291            if !header.dbname.is_empty() {
292                parse_catalog_and_schema_from_db_string(&header.dbname)
293            } else {
294                (
295                    if !header.catalog.is_empty() {
296                        header.catalog.to_lowercase()
297                    } else {
298                        DEFAULT_CATALOG_NAME.to_string()
299                    },
300                    if !header.schema.is_empty() {
301                        header.schema.to_lowercase()
302                    } else {
303                        DEFAULT_SCHEMA_NAME.to_string()
304                    },
305                )
306            }
307        })
308        .unwrap_or_else(|| {
309            (
310                DEFAULT_CATALOG_NAME.to_string(),
311                DEFAULT_SCHEMA_NAME.to_string(),
312            )
313        });
314    let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
315    let mut ctx_builder = QueryContextBuilder::default()
316        .current_catalog(catalog)
317        .current_schema(schema)
318        .timezone(timezone);
319
320    if let Some(x) = extensions
321        .iter()
322        .position(|(k, _)| k == READ_PREFERENCE_HINT)
323    {
324        let (k, v) = extensions.swap_remove(x);
325        let Ok(read_preference) = ReadPreference::from_str(&v) else {
326            return UnknownHintSnafu {
327                hint: format!("{k}={v}"),
328            }
329            .fail();
330        };
331        ctx_builder = ctx_builder.read_preference(read_preference);
332    }
333
334    for (key, value) in extensions {
335        ctx_builder = ctx_builder.set_extension(key, value);
336    }
337    Ok(ctx_builder.build().into())
338}
339
340/// Histogram timer for handling gRPC request.
341///
342/// The timer records the elapsed time with [StatusCode::Success] on drop.
343pub(crate) struct RequestTimer {
344    start: Instant,
345    db: String,
346    request_type: String,
347    status_code: StatusCode,
348}
349
350impl RequestTimer {
351    /// Returns a new timer.
352    pub fn new(db: String, request_type: String) -> RequestTimer {
353        RequestTimer {
354            start: Instant::now(),
355            db,
356            request_type,
357            status_code: StatusCode::Success,
358        }
359    }
360
361    /// Consumes the timer and record the elapsed time with specific `status_code`.
362    pub fn record(mut self, status_code: StatusCode) {
363        self.status_code = status_code;
364    }
365}
366
367impl Drop for RequestTimer {
368    fn drop(&mut self) {
369        METRIC_SERVER_GRPC_DB_REQUEST_TIMER
370            .with_label_values(&[
371                self.db.as_str(),
372                self.request_type.as_str(),
373                self.status_code.as_ref(),
374            ])
375            .observe(self.start.elapsed().as_secs_f64());
376    }
377}
378
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_time::Timezone;

    use super::*;

    /// Legacy `catalog`/`schema` header fields, timezone, read preference and extensions.
    #[test]
    fn test_create_query_context() {
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };
        let query_context = create_query_context(
            Some(&header),
            vec![
                ("auto_create_table".to_string(), "true".to_string()),
                ("read_preference".to_string(), "leader".to_string()),
            ],
        )
        .unwrap();
        assert_eq!(query_context.current_catalog(), "cat-a-log");
        assert_eq!(query_context.current_schema(), DEFAULT_SCHEMA_NAME);
        assert_eq!(
            query_context.timezone(),
            Timezone::Offset(FixedOffset::east_opt(3600).unwrap())
        );
        assert!(matches!(
            query_context.read_preference(),
            ReadPreference::Leader
        ));
        assert_eq!(
            query_context.extensions().into_iter().collect::<Vec<_>>(),
            vec![("auto_create_table".to_string(), "true".to_string())]
        );
    }

    /// A non-empty `dbname` wins over the legacy `catalog`/`schema` fields.
    #[test]
    fn test_create_query_context_prefers_dbname() {
        let header = RequestHeader {
            catalog: "ignored_catalog".to_string(),
            schema: "ignored_schema".to_string(),
            dbname: "my_db".to_string(),
            ..Default::default()
        };
        let query_context = create_query_context(Some(&header), vec![]).unwrap();
        assert_eq!(query_context.current_catalog(), DEFAULT_CATALOG_NAME);
        assert_eq!(query_context.current_schema(), "my_db");
    }

    /// Without a header, the default catalog and schema are used.
    #[test]
    fn test_create_query_context_without_header() {
        let query_context = create_query_context(None, vec![]).unwrap();
        assert_eq!(query_context.current_catalog(), DEFAULT_CATALOG_NAME);
        assert_eq!(query_context.current_schema(), DEFAULT_SCHEMA_NAME);
    }

    /// An unparsable read-preference hint is rejected instead of being silently dropped.
    #[test]
    fn test_create_query_context_unknown_read_preference() {
        let result = create_query_context(
            None,
            vec![(
                "read_preference".to_string(),
                "not-a-preference".to_string(),
            )],
        );
        assert!(result.is_err());
    }
}