servers/grpc/
greptime_handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Handler for Greptime Database service. It's implemented by frontend.
16
17use std::str::FromStr;
18use std::time::Instant;
19
20use api::helper::request_type;
21use api::v1::auth_header::AuthScheme;
22use api::v1::{AuthHeader, Basic, GreptimeRequest, RequestHeader};
23use auth::{Identity, Password, UserInfoRef, UserProviderRef};
24use base64::prelude::BASE64_STANDARD;
25use base64::Engine;
26use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
27use common_catalog::parse_catalog_and_schema_from_db_string;
28use common_error::ext::ErrorExt;
29use common_error::status_code::StatusCode;
30use common_grpc::flight::do_put::DoPutResponse;
31use common_grpc::flight::FlightDecoder;
32use common_query::Output;
33use common_runtime::runtime::RuntimeTrait;
34use common_runtime::Runtime;
35use common_session::ReadPreference;
36use common_telemetry::tracing_context::{FutureExt, TracingContext};
37use common_telemetry::{debug, error, tracing, warn};
38use common_time::timezone::parse_timezone;
39use futures_util::StreamExt;
40use session::context::{Channel, QueryContext, QueryContextBuilder, QueryContextRef};
41use session::hints::READ_PREFERENCE_HINT;
42use snafu::{OptionExt, ResultExt};
43use table::TableRef;
44use tokio::sync::mpsc;
45use tokio::sync::mpsc::error::TrySendError;
46
47use crate::error::Error::UnsupportedAuthScheme;
48use crate::error::{
49    AuthSnafu, InvalidAuthHeaderInvalidUtf8ValueSnafu, InvalidBase64ValueSnafu, InvalidQuerySnafu,
50    JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result, UnknownHintSnafu,
51};
52use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream};
53use crate::grpc::{FlightCompression, TonicResult};
54use crate::metrics;
55use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
56use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
57
58#[derive(Clone)]
59pub struct GreptimeRequestHandler {
60    handler: ServerGrpcQueryHandlerRef,
61    user_provider: Option<UserProviderRef>,
62    runtime: Option<Runtime>,
63    pub(crate) flight_compression: FlightCompression,
64}
65
66impl GreptimeRequestHandler {
67    pub fn new(
68        handler: ServerGrpcQueryHandlerRef,
69        user_provider: Option<UserProviderRef>,
70        runtime: Option<Runtime>,
71        flight_compression: FlightCompression,
72    ) -> Self {
73        Self {
74            handler,
75            user_provider,
76            runtime,
77            flight_compression,
78        }
79    }
80
81    #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
82    pub(crate) async fn handle_request(
83        &self,
84        request: GreptimeRequest,
85        hints: Vec<(String, String)>,
86    ) -> Result<Output> {
87        let query = request.request.context(InvalidQuerySnafu {
88            reason: "Expecting non-empty GreptimeRequest.",
89        })?;
90
91        let header = request.header.as_ref();
92        let query_ctx = create_query_context(Channel::Grpc, header, hints)?;
93        let user_info = auth(self.user_provider.clone(), header, &query_ctx).await?;
94        query_ctx.set_current_user(user_info);
95
96        let handler = self.handler.clone();
97        let request_type = request_type(&query).to_string();
98        let db = query_ctx.get_db_string();
99        let timer = RequestTimer::new(db.clone(), request_type);
100        let tracing_context = TracingContext::from_current_span();
101
102        let result_future = async move {
103            handler
104                .do_query(query, query_ctx)
105                .trace(tracing_context.attach(tracing::info_span!(
106                    "GreptimeRequestHandler::handle_request_runtime"
107                )))
108                .await
109                .map_err(|e| {
110                    if e.status_code().should_log_error() {
111                        let root_error = e.root_cause().unwrap_or(&e);
112                        error!(e; "Failed to handle request, error: {}", root_error.to_string());
113                    } else {
114                        // Currently, we still print a debug log.
115                        debug!("Failed to handle request, err: {:?}", e);
116                    }
117                    e
118                })
119        };
120
121        match &self.runtime {
122            Some(runtime) => {
123                // Executes requests in another runtime to
124                // 1. prevent the execution from being cancelled unexpected by Tonic runtime;
125                //   - Refer to our blog for the rational behind it:
126                //     https://www.greptime.com/blogs/2023-01-12-hidden-control-flow.html
127                //   - Obtaining a `JoinHandle` to get the panic message (if there's any).
128                //     From its docs, `JoinHandle` is cancel safe. The task keeps running even it's handle been dropped.
129                // 2. avoid the handler blocks the gRPC runtime incidentally.
130                runtime
131                    .spawn(result_future)
132                    .await
133                    .context(JoinTaskSnafu)
134                    .inspect_err(|e| {
135                        timer.record(e.status_code());
136                    })?
137            }
138            None => result_future.await,
139        }
140    }
141
142    pub(crate) async fn put_record_batches(
143        &self,
144        mut stream: PutRecordBatchRequestStream,
145        result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
146    ) {
147        let handler = self.handler.clone();
148        let runtime = self
149            .runtime
150            .clone()
151            .unwrap_or_else(common_runtime::global_runtime);
152        runtime.spawn(async move {
153            // Cached table ref
154            let mut table_ref: Option<TableRef> = None;
155
156            let mut decoder = FlightDecoder::default();
157            while let Some(request) = stream.next().await {
158                let request = match request {
159                    Ok(request) => request,
160                    Err(e) => {
161                        let _ = result_sender.try_send(Err(e));
162                        break;
163                    }
164                };
165                let PutRecordBatchRequest {
166                    table_name,
167                    request_id,
168                    data,
169                } = request;
170
171                let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer();
172                let result = handler
173                    .put_record_batch(&table_name, &mut table_ref, &mut decoder, data)
174                    .await
175                    .inspect_err(|e| error!(e; "Failed to handle flight record batches"));
176                timer.observe_duration();
177                let result = result
178                    .map(|x| DoPutResponse::new(request_id, x))
179                    .map_err(Into::into);
180                if let Err(e)= result_sender.try_send(result)
181                    && let TrySendError::Closed(_) = e {
182                    warn!(r#""DoPut" client with request_id {} maybe unreachable, abort handling its message"#, request_id);
183                    break;
184                }
185            }
186        });
187    }
188
189    pub(crate) async fn validate_auth(
190        &self,
191        username_and_password: Option<&str>,
192        db: Option<&str>,
193    ) -> Result<bool> {
194        if self.user_provider.is_none() {
195            return Ok(true);
196        }
197
198        let username_and_password = username_and_password.context(NotFoundAuthHeaderSnafu)?;
199        let username_and_password = BASE64_STANDARD
200            .decode(username_and_password)
201            .context(InvalidBase64ValueSnafu)
202            .and_then(|x| String::from_utf8(x).context(InvalidAuthHeaderInvalidUtf8ValueSnafu))?;
203
204        let mut split = username_and_password.splitn(2, ':');
205        let (username, password) = match (split.next(), split.next()) {
206            (Some(username), Some(password)) => (username, password),
207            (Some(username), None) => (username, ""),
208            (None, None) => return Ok(false),
209            _ => unreachable!(), // because this iterator won't yield Some after None
210        };
211
212        let (catalog, schema) = if let Some(db) = db {
213            parse_catalog_and_schema_from_db_string(db)
214        } else {
215            (
216                DEFAULT_CATALOG_NAME.to_string(),
217                DEFAULT_SCHEMA_NAME.to_string(),
218            )
219        };
220        let header = RequestHeader {
221            authorization: Some(AuthHeader {
222                auth_scheme: Some(AuthScheme::Basic(Basic {
223                    username: username.to_string(),
224                    password: password.to_string(),
225                })),
226            }),
227            catalog,
228            schema,
229            ..Default::default()
230        };
231
232        Ok(auth(
233            self.user_provider.clone(),
234            Some(&header),
235            &QueryContext::arc(),
236        )
237        .await
238        .is_ok())
239    }
240}
241
242pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
243    request
244        .request
245        .as_ref()
246        .map(request_type)
247        .unwrap_or_default()
248}
249
250pub(crate) async fn auth(
251    user_provider: Option<UserProviderRef>,
252    header: Option<&RequestHeader>,
253    query_ctx: &QueryContextRef,
254) -> Result<UserInfoRef> {
255    let Some(user_provider) = user_provider else {
256        return Ok(auth::userinfo_by_name(None));
257    };
258
259    let auth_scheme = header
260        .and_then(|header| {
261            header
262                .authorization
263                .as_ref()
264                .and_then(|x| x.auth_scheme.clone())
265        })
266        .context(NotFoundAuthHeaderSnafu)?;
267
268    match auth_scheme {
269        AuthScheme::Basic(Basic { username, password }) => user_provider
270            .auth(
271                Identity::UserId(&username, None),
272                Password::PlainText(password.into()),
273                query_ctx.current_catalog(),
274                &query_ctx.current_schema(),
275            )
276            .await
277            .context(AuthSnafu),
278        AuthScheme::Token(_) => Err(UnsupportedAuthScheme {
279            name: "Token AuthScheme".to_string(),
280        }),
281    }
282    .inspect_err(|e| {
283        METRIC_AUTH_FAILURE
284            .with_label_values(&[e.status_code().as_ref()])
285            .inc();
286    })
287}
288
289/// Creates a new `QueryContext` from the provided request header and extensions.
290/// Strongly recommend setting an appropriate channel, as this is very helpful for statistics.
291pub(crate) fn create_query_context(
292    channel: Channel,
293    header: Option<&RequestHeader>,
294    mut extensions: Vec<(String, String)>,
295) -> Result<QueryContextRef> {
296    let (catalog, schema) = header
297        .map(|header| {
298            // We provide dbname field in newer versions of protos/sdks
299            // parse dbname from header in priority
300            if !header.dbname.is_empty() {
301                parse_catalog_and_schema_from_db_string(&header.dbname)
302            } else {
303                (
304                    if !header.catalog.is_empty() {
305                        header.catalog.to_lowercase()
306                    } else {
307                        DEFAULT_CATALOG_NAME.to_string()
308                    },
309                    if !header.schema.is_empty() {
310                        header.schema.to_lowercase()
311                    } else {
312                        DEFAULT_SCHEMA_NAME.to_string()
313                    },
314                )
315            }
316        })
317        .unwrap_or_else(|| {
318            (
319                DEFAULT_CATALOG_NAME.to_string(),
320                DEFAULT_SCHEMA_NAME.to_string(),
321            )
322        });
323    let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
324    let mut ctx_builder = QueryContextBuilder::default()
325        .current_catalog(catalog)
326        .current_schema(schema)
327        .timezone(timezone)
328        .channel(channel);
329
330    if let Some(x) = extensions
331        .iter()
332        .position(|(k, _)| k == READ_PREFERENCE_HINT)
333    {
334        let (k, v) = extensions.swap_remove(x);
335        let Ok(read_preference) = ReadPreference::from_str(&v) else {
336            return UnknownHintSnafu {
337                hint: format!("{k}={v}"),
338            }
339            .fail();
340        };
341        ctx_builder = ctx_builder.read_preference(read_preference);
342    }
343
344    for (key, value) in extensions {
345        ctx_builder = ctx_builder.set_extension(key, value);
346    }
347    Ok(ctx_builder.build().into())
348}
349
350/// Histogram timer for handling gRPC request.
351///
352/// The timer records the elapsed time with [StatusCode::Success] on drop.
353pub(crate) struct RequestTimer {
354    start: Instant,
355    db: String,
356    request_type: String,
357    status_code: StatusCode,
358}
359
360impl RequestTimer {
361    /// Returns a new timer.
362    pub fn new(db: String, request_type: String) -> RequestTimer {
363        RequestTimer {
364            start: Instant::now(),
365            db,
366            request_type,
367            status_code: StatusCode::Success,
368        }
369    }
370
371    /// Consumes the timer and record the elapsed time with specific `status_code`.
372    pub fn record(mut self, status_code: StatusCode) {
373        self.status_code = status_code;
374    }
375}
376
377impl Drop for RequestTimer {
378    fn drop(&mut self) {
379        METRIC_SERVER_GRPC_DB_REQUEST_TIMER
380            .with_label_values(&[
381                self.db.as_str(),
382                self.request_type.as_str(),
383                self.status_code.as_ref(),
384            ])
385            .observe(self.start.elapsed().as_secs_f64());
386    }
387}
388
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_time::Timezone;

    use super::*;

    /// The header's catalog and timezone are applied, the schema falls back to the
    /// default, and the read preference hint is consumed instead of being kept as an
    /// extension.
    #[test]
    fn test_create_query_context() {
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };
        let hints = vec![
            ("auto_create_table".to_string(), "true".to_string()),
            ("read_preference".to_string(), "leader".to_string()),
        ];

        let query_context = create_query_context(Channel::Unknown, Some(&header), hints).unwrap();

        assert_eq!(query_context.current_catalog(), "cat-a-log");
        assert_eq!(query_context.current_schema(), DEFAULT_SCHEMA_NAME);

        let one_hour_east = Timezone::Offset(FixedOffset::east_opt(3600).unwrap());
        assert_eq!(query_context.timezone(), one_hour_east);

        assert!(matches!(
            query_context.read_preference(),
            ReadPreference::Leader
        ));

        let remaining = query_context.extensions().into_iter().collect::<Vec<_>>();
        assert_eq!(
            remaining,
            vec![("auto_create_table".to_string(), "true".to_string())]
        );
    }
}
427}