// servers/grpc/greptime_handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Handler for Greptime Database service. It's implemented by frontend.
16
17use std::str::FromStr;
18use std::time::Instant;
19
20use api::helper::request_type;
21use api::v1::auth_header::AuthScheme;
22use api::v1::{AuthHeader, Basic, GreptimeRequest, RequestHeader};
23use auth::{Identity, Password, UserInfoRef, UserProviderRef};
24use base64::prelude::BASE64_STANDARD;
25use base64::Engine;
26use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
27use common_catalog::parse_catalog_and_schema_from_db_string;
28use common_error::ext::ErrorExt;
29use common_error::status_code::StatusCode;
30use common_grpc::flight::do_put::DoPutResponse;
31use common_grpc::flight::FlightDecoder;
32use common_query::Output;
33use common_runtime::runtime::RuntimeTrait;
34use common_runtime::Runtime;
35use common_session::ReadPreference;
36use common_telemetry::tracing_context::{FutureExt, TracingContext};
37use common_telemetry::{debug, error, tracing, warn};
38use common_time::timezone::parse_timezone;
39use futures_util::StreamExt;
40use session::context::{QueryContext, QueryContextBuilder, QueryContextRef};
41use session::hints::READ_PREFERENCE_HINT;
42use snafu::{OptionExt, ResultExt};
43use table::metadata::TableId;
44use tokio::sync::mpsc;
45
46use crate::error::Error::UnsupportedAuthScheme;
47use crate::error::{
48    AuthSnafu, InvalidAuthHeaderInvalidUtf8ValueSnafu, InvalidBase64ValueSnafu, InvalidQuerySnafu,
49    JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result, UnknownHintSnafu,
50};
51use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream};
52use crate::grpc::TonicResult;
53use crate::metrics;
54use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
55use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
56
57#[derive(Clone)]
58pub struct GreptimeRequestHandler {
59    handler: ServerGrpcQueryHandlerRef,
60    user_provider: Option<UserProviderRef>,
61    runtime: Option<Runtime>,
62}
63
64impl GreptimeRequestHandler {
65    pub fn new(
66        handler: ServerGrpcQueryHandlerRef,
67        user_provider: Option<UserProviderRef>,
68        runtime: Option<Runtime>,
69    ) -> Self {
70        Self {
71            handler,
72            user_provider,
73            runtime,
74        }
75    }
76
77    #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
78    pub(crate) async fn handle_request(
79        &self,
80        request: GreptimeRequest,
81        hints: Vec<(String, String)>,
82    ) -> Result<Output> {
83        let query = request.request.context(InvalidQuerySnafu {
84            reason: "Expecting non-empty GreptimeRequest.",
85        })?;
86
87        let header = request.header.as_ref();
88        let query_ctx = create_query_context(header, hints)?;
89        let user_info = auth(self.user_provider.clone(), header, &query_ctx).await?;
90        query_ctx.set_current_user(user_info);
91
92        let handler = self.handler.clone();
93        let request_type = request_type(&query).to_string();
94        let db = query_ctx.get_db_string();
95        let timer = RequestTimer::new(db.clone(), request_type);
96        let tracing_context = TracingContext::from_current_span();
97
98        let result_future = async move {
99            handler
100                .do_query(query, query_ctx)
101                .trace(tracing_context.attach(tracing::info_span!(
102                    "GreptimeRequestHandler::handle_request_runtime"
103                )))
104                .await
105                .map_err(|e| {
106                    if e.status_code().should_log_error() {
107                        let root_error = e.root_cause().unwrap_or(&e);
108                        error!(e; "Failed to handle request, error: {}", root_error.to_string());
109                    } else {
110                        // Currently, we still print a debug log.
111                        debug!("Failed to handle request, err: {:?}", e);
112                    }
113                    e
114                })
115        };
116
117        match &self.runtime {
118            Some(runtime) => {
119                // Executes requests in another runtime to
120                // 1. prevent the execution from being cancelled unexpected by Tonic runtime;
121                //   - Refer to our blog for the rational behind it:
122                //     https://www.greptime.com/blogs/2023-01-12-hidden-control-flow.html
123                //   - Obtaining a `JoinHandle` to get the panic message (if there's any).
124                //     From its docs, `JoinHandle` is cancel safe. The task keeps running even it's handle been dropped.
125                // 2. avoid the handler blocks the gRPC runtime incidentally.
126                runtime
127                    .spawn(result_future)
128                    .await
129                    .context(JoinTaskSnafu)
130                    .inspect_err(|e| {
131                        timer.record(e.status_code());
132                    })?
133            }
134            None => result_future.await,
135        }
136    }
137
138    pub(crate) async fn put_record_batches(
139        &self,
140        mut stream: PutRecordBatchRequestStream,
141        result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
142    ) {
143        let handler = self.handler.clone();
144        let runtime = self
145            .runtime
146            .clone()
147            .unwrap_or_else(common_runtime::global_runtime);
148        runtime.spawn(async move {
149            // Cached table id
150            let mut table_id: Option<TableId> = None;
151
152            let mut decoder = FlightDecoder::default();
153            while let Some(request) = stream.next().await {
154                let request = match request {
155                    Ok(request) => request,
156                    Err(e) => {
157                        let _ = result_sender.try_send(Err(e));
158                        break;
159                    }
160                };
161                let PutRecordBatchRequest {
162                    table_name,
163                    request_id,
164                    data,
165                } = request;
166
167                let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer();
168                let result = handler
169                    .put_record_batch(&table_name, &mut table_id, &mut decoder, data)
170                    .await
171                    .inspect_err(|e| error!(e; "Failed to handle flight record batches"));
172                timer.observe_duration();
173                let result = result
174                    .map(|x| DoPutResponse::new(request_id, x))
175                    .map_err(Into::into);
176                if result_sender.try_send(result).is_err() {
177                    warn!(r#""DoPut" client maybe unreachable, abort handling its message"#);
178                    break;
179                }
180            }
181        });
182    }
183
184    pub(crate) async fn validate_auth(
185        &self,
186        username_and_password: Option<&str>,
187        db: Option<&str>,
188    ) -> Result<bool> {
189        if self.user_provider.is_none() {
190            return Ok(true);
191        }
192
193        let username_and_password = username_and_password.context(NotFoundAuthHeaderSnafu)?;
194        let username_and_password = BASE64_STANDARD
195            .decode(username_and_password)
196            .context(InvalidBase64ValueSnafu)
197            .and_then(|x| String::from_utf8(x).context(InvalidAuthHeaderInvalidUtf8ValueSnafu))?;
198
199        let mut split = username_and_password.splitn(2, ':');
200        let (username, password) = match (split.next(), split.next()) {
201            (Some(username), Some(password)) => (username, password),
202            (Some(username), None) => (username, ""),
203            (None, None) => return Ok(false),
204            _ => unreachable!(), // because this iterator won't yield Some after None
205        };
206
207        let (catalog, schema) = if let Some(db) = db {
208            parse_catalog_and_schema_from_db_string(db)
209        } else {
210            (
211                DEFAULT_CATALOG_NAME.to_string(),
212                DEFAULT_SCHEMA_NAME.to_string(),
213            )
214        };
215        let header = RequestHeader {
216            authorization: Some(AuthHeader {
217                auth_scheme: Some(AuthScheme::Basic(Basic {
218                    username: username.to_string(),
219                    password: password.to_string(),
220                })),
221            }),
222            catalog,
223            schema,
224            ..Default::default()
225        };
226
227        Ok(auth(
228            self.user_provider.clone(),
229            Some(&header),
230            &QueryContext::arc(),
231        )
232        .await
233        .is_ok())
234    }
235}
236
237pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
238    request
239        .request
240        .as_ref()
241        .map(request_type)
242        .unwrap_or_default()
243}
244
245pub(crate) async fn auth(
246    user_provider: Option<UserProviderRef>,
247    header: Option<&RequestHeader>,
248    query_ctx: &QueryContextRef,
249) -> Result<UserInfoRef> {
250    let Some(user_provider) = user_provider else {
251        return Ok(auth::userinfo_by_name(None));
252    };
253
254    let auth_scheme = header
255        .and_then(|header| {
256            header
257                .authorization
258                .as_ref()
259                .and_then(|x| x.auth_scheme.clone())
260        })
261        .context(NotFoundAuthHeaderSnafu)?;
262
263    match auth_scheme {
264        AuthScheme::Basic(Basic { username, password }) => user_provider
265            .auth(
266                Identity::UserId(&username, None),
267                Password::PlainText(password.into()),
268                query_ctx.current_catalog(),
269                &query_ctx.current_schema(),
270            )
271            .await
272            .context(AuthSnafu),
273        AuthScheme::Token(_) => Err(UnsupportedAuthScheme {
274            name: "Token AuthScheme".to_string(),
275        }),
276    }
277    .inspect_err(|e| {
278        METRIC_AUTH_FAILURE
279            .with_label_values(&[e.status_code().as_ref()])
280            .inc();
281    })
282}
283
284pub(crate) fn create_query_context(
285    header: Option<&RequestHeader>,
286    mut extensions: Vec<(String, String)>,
287) -> Result<QueryContextRef> {
288    let (catalog, schema) = header
289        .map(|header| {
290            // We provide dbname field in newer versions of protos/sdks
291            // parse dbname from header in priority
292            if !header.dbname.is_empty() {
293                parse_catalog_and_schema_from_db_string(&header.dbname)
294            } else {
295                (
296                    if !header.catalog.is_empty() {
297                        header.catalog.to_lowercase()
298                    } else {
299                        DEFAULT_CATALOG_NAME.to_string()
300                    },
301                    if !header.schema.is_empty() {
302                        header.schema.to_lowercase()
303                    } else {
304                        DEFAULT_SCHEMA_NAME.to_string()
305                    },
306                )
307            }
308        })
309        .unwrap_or_else(|| {
310            (
311                DEFAULT_CATALOG_NAME.to_string(),
312                DEFAULT_SCHEMA_NAME.to_string(),
313            )
314        });
315    let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
316    let mut ctx_builder = QueryContextBuilder::default()
317        .current_catalog(catalog)
318        .current_schema(schema)
319        .timezone(timezone);
320
321    if let Some(x) = extensions
322        .iter()
323        .position(|(k, _)| k == READ_PREFERENCE_HINT)
324    {
325        let (k, v) = extensions.swap_remove(x);
326        let Ok(read_preference) = ReadPreference::from_str(&v) else {
327            return UnknownHintSnafu {
328                hint: format!("{k}={v}"),
329            }
330            .fail();
331        };
332        ctx_builder = ctx_builder.read_preference(read_preference);
333    }
334
335    for (key, value) in extensions {
336        ctx_builder = ctx_builder.set_extension(key, value);
337    }
338    Ok(ctx_builder.build().into())
339}
340
341/// Histogram timer for handling gRPC request.
342///
343/// The timer records the elapsed time with [StatusCode::Success] on drop.
344pub(crate) struct RequestTimer {
345    start: Instant,
346    db: String,
347    request_type: String,
348    status_code: StatusCode,
349}
350
351impl RequestTimer {
352    /// Returns a new timer.
353    pub fn new(db: String, request_type: String) -> RequestTimer {
354        RequestTimer {
355            start: Instant::now(),
356            db,
357            request_type,
358            status_code: StatusCode::Success,
359        }
360    }
361
362    /// Consumes the timer and record the elapsed time with specific `status_code`.
363    pub fn record(mut self, status_code: StatusCode) {
364        self.status_code = status_code;
365    }
366}
367
368impl Drop for RequestTimer {
369    fn drop(&mut self) {
370        METRIC_SERVER_GRPC_DB_REQUEST_TIMER
371            .with_label_values(&[
372                self.db.as_str(),
373                self.request_type.as_str(),
374                self.status_code.as_ref(),
375            ])
376            .observe(self.start.elapsed().as_secs_f64());
377    }
378}
379
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_time::Timezone;

    use super::*;

    /// A header with only a catalog and timezone, plus two hints, yields:
    /// - the header catalog and the default schema,
    /// - a +01:00 timezone,
    /// - a leader read preference, and
    /// - the non-hint entry kept as the only extension.
    #[test]
    fn test_create_query_context() {
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };
        let hints = vec![
            ("auto_create_table".to_string(), "true".to_string()),
            ("read_preference".to_string(), "leader".to_string()),
        ];

        let ctx = create_query_context(Some(&header), hints).unwrap();

        assert_eq!(ctx.current_catalog(), "cat-a-log");
        assert_eq!(ctx.current_schema(), DEFAULT_SCHEMA_NAME);

        let one_hour_east = FixedOffset::east_opt(3600).unwrap();
        assert_eq!(ctx.timezone(), Timezone::Offset(one_hour_east));

        assert!(matches!(ctx.read_preference(), ReadPreference::Leader));

        let extensions = ctx.extensions().into_iter().collect::<Vec<_>>();
        let expected = vec![("auto_create_table".to_string(), "true".to_string())];
        assert_eq!(extensions, expected);
    }
}