servers/grpc/greptime_handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Handler for Greptime Database service. It's implemented by frontend.
16
17use std::str::FromStr;
18use std::time::Instant;
19
20use api::helper::request_type;
21use api::v1::auth_header::AuthScheme;
22use api::v1::{AuthHeader, Basic, GreptimeRequest, RequestHeader};
23use auth::{Identity, Password, UserInfoRef, UserProviderRef};
24use base64::prelude::BASE64_STANDARD;
25use base64::Engine;
26use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
27use common_catalog::parse_catalog_and_schema_from_db_string;
28use common_error::ext::ErrorExt;
29use common_error::status_code::StatusCode;
30use common_grpc::flight::do_put::DoPutResponse;
31use common_grpc::flight::FlightDecoder;
32use common_query::Output;
33use common_runtime::runtime::RuntimeTrait;
34use common_runtime::Runtime;
35use common_session::ReadPreference;
36use common_telemetry::tracing_context::{FutureExt, TracingContext};
37use common_telemetry::{debug, error, tracing, warn};
38use common_time::timezone::parse_timezone;
39use futures_util::StreamExt;
40use session::context::{QueryContext, QueryContextBuilder, QueryContextRef};
41use session::hints::READ_PREFERENCE_HINT;
42use snafu::{OptionExt, ResultExt};
43use table::TableRef;
44use tokio::sync::mpsc;
45
46use crate::error::Error::UnsupportedAuthScheme;
47use crate::error::{
48    AuthSnafu, InvalidAuthHeaderInvalidUtf8ValueSnafu, InvalidBase64ValueSnafu, InvalidQuerySnafu,
49    JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result, UnknownHintSnafu,
50};
51use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream};
52use crate::grpc::{FlightCompression, TonicResult};
53use crate::metrics;
54use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
55use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
56
/// Request handler of the Greptime Database gRPC service.
///
/// It authenticates callers and dispatches queries and Flight `DoPut` record
/// batches to the wrapped query handler.
#[derive(Clone)]
pub struct GreptimeRequestHandler {
    /// Query handler that requests are dispatched to.
    handler: ServerGrpcQueryHandlerRef,
    /// Authenticates callers; when `None`, authentication is skipped.
    user_provider: Option<UserProviderRef>,
    /// Optional dedicated runtime on which requests are executed instead of the
    /// gRPC server's own runtime.
    runtime: Option<Runtime>,
    /// Compression setting for Flight responses. It is not read in this file;
    /// presumably it is consumed by the Flight service (confirm).
    pub(crate) flight_compression: FlightCompression,
}
64
65impl GreptimeRequestHandler {
66    pub fn new(
67        handler: ServerGrpcQueryHandlerRef,
68        user_provider: Option<UserProviderRef>,
69        runtime: Option<Runtime>,
70        flight_compression: FlightCompression,
71    ) -> Self {
72        Self {
73            handler,
74            user_provider,
75            runtime,
76            flight_compression,
77        }
78    }
79
80    #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
81    pub(crate) async fn handle_request(
82        &self,
83        request: GreptimeRequest,
84        hints: Vec<(String, String)>,
85    ) -> Result<Output> {
86        let query = request.request.context(InvalidQuerySnafu {
87            reason: "Expecting non-empty GreptimeRequest.",
88        })?;
89
90        let header = request.header.as_ref();
91        let query_ctx = create_query_context(header, hints)?;
92        let user_info = auth(self.user_provider.clone(), header, &query_ctx).await?;
93        query_ctx.set_current_user(user_info);
94
95        let handler = self.handler.clone();
96        let request_type = request_type(&query).to_string();
97        let db = query_ctx.get_db_string();
98        let timer = RequestTimer::new(db.clone(), request_type);
99        let tracing_context = TracingContext::from_current_span();
100
101        let result_future = async move {
102            handler
103                .do_query(query, query_ctx)
104                .trace(tracing_context.attach(tracing::info_span!(
105                    "GreptimeRequestHandler::handle_request_runtime"
106                )))
107                .await
108                .map_err(|e| {
109                    if e.status_code().should_log_error() {
110                        let root_error = e.root_cause().unwrap_or(&e);
111                        error!(e; "Failed to handle request, error: {}", root_error.to_string());
112                    } else {
113                        // Currently, we still print a debug log.
114                        debug!("Failed to handle request, err: {:?}", e);
115                    }
116                    e
117                })
118        };
119
120        match &self.runtime {
121            Some(runtime) => {
122                // Executes requests in another runtime to
123                // 1. prevent the execution from being cancelled unexpected by Tonic runtime;
124                //   - Refer to our blog for the rational behind it:
125                //     https://www.greptime.com/blogs/2023-01-12-hidden-control-flow.html
126                //   - Obtaining a `JoinHandle` to get the panic message (if there's any).
127                //     From its docs, `JoinHandle` is cancel safe. The task keeps running even it's handle been dropped.
128                // 2. avoid the handler blocks the gRPC runtime incidentally.
129                runtime
130                    .spawn(result_future)
131                    .await
132                    .context(JoinTaskSnafu)
133                    .inspect_err(|e| {
134                        timer.record(e.status_code());
135                    })?
136            }
137            None => result_future.await,
138        }
139    }
140
141    pub(crate) async fn put_record_batches(
142        &self,
143        mut stream: PutRecordBatchRequestStream,
144        result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
145    ) {
146        let handler = self.handler.clone();
147        let runtime = self
148            .runtime
149            .clone()
150            .unwrap_or_else(common_runtime::global_runtime);
151        runtime.spawn(async move {
152            // Cached table ref
153            let mut table_ref: Option<TableRef> = None;
154
155            let mut decoder = FlightDecoder::default();
156            while let Some(request) = stream.next().await {
157                let request = match request {
158                    Ok(request) => request,
159                    Err(e) => {
160                        let _ = result_sender.try_send(Err(e));
161                        break;
162                    }
163                };
164                let PutRecordBatchRequest {
165                    table_name,
166                    request_id,
167                    data,
168                } = request;
169
170                let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer();
171                let result = handler
172                    .put_record_batch(&table_name, &mut table_ref, &mut decoder, data)
173                    .await
174                    .inspect_err(|e| error!(e; "Failed to handle flight record batches"));
175                timer.observe_duration();
176                let result = result
177                    .map(|x| DoPutResponse::new(request_id, x))
178                    .map_err(Into::into);
179                if result_sender.try_send(result).is_err() {
180                    warn!(r#""DoPut" client maybe unreachable, abort handling its message"#);
181                    break;
182                }
183            }
184        });
185    }
186
187    pub(crate) async fn validate_auth(
188        &self,
189        username_and_password: Option<&str>,
190        db: Option<&str>,
191    ) -> Result<bool> {
192        if self.user_provider.is_none() {
193            return Ok(true);
194        }
195
196        let username_and_password = username_and_password.context(NotFoundAuthHeaderSnafu)?;
197        let username_and_password = BASE64_STANDARD
198            .decode(username_and_password)
199            .context(InvalidBase64ValueSnafu)
200            .and_then(|x| String::from_utf8(x).context(InvalidAuthHeaderInvalidUtf8ValueSnafu))?;
201
202        let mut split = username_and_password.splitn(2, ':');
203        let (username, password) = match (split.next(), split.next()) {
204            (Some(username), Some(password)) => (username, password),
205            (Some(username), None) => (username, ""),
206            (None, None) => return Ok(false),
207            _ => unreachable!(), // because this iterator won't yield Some after None
208        };
209
210        let (catalog, schema) = if let Some(db) = db {
211            parse_catalog_and_schema_from_db_string(db)
212        } else {
213            (
214                DEFAULT_CATALOG_NAME.to_string(),
215                DEFAULT_SCHEMA_NAME.to_string(),
216            )
217        };
218        let header = RequestHeader {
219            authorization: Some(AuthHeader {
220                auth_scheme: Some(AuthScheme::Basic(Basic {
221                    username: username.to_string(),
222                    password: password.to_string(),
223                })),
224            }),
225            catalog,
226            schema,
227            ..Default::default()
228        };
229
230        Ok(auth(
231            self.user_provider.clone(),
232            Some(&header),
233            &QueryContext::arc(),
234        )
235        .await
236        .is_ok())
237    }
238}
239
240pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
241    request
242        .request
243        .as_ref()
244        .map(request_type)
245        .unwrap_or_default()
246}
247
248pub(crate) async fn auth(
249    user_provider: Option<UserProviderRef>,
250    header: Option<&RequestHeader>,
251    query_ctx: &QueryContextRef,
252) -> Result<UserInfoRef> {
253    let Some(user_provider) = user_provider else {
254        return Ok(auth::userinfo_by_name(None));
255    };
256
257    let auth_scheme = header
258        .and_then(|header| {
259            header
260                .authorization
261                .as_ref()
262                .and_then(|x| x.auth_scheme.clone())
263        })
264        .context(NotFoundAuthHeaderSnafu)?;
265
266    match auth_scheme {
267        AuthScheme::Basic(Basic { username, password }) => user_provider
268            .auth(
269                Identity::UserId(&username, None),
270                Password::PlainText(password.into()),
271                query_ctx.current_catalog(),
272                &query_ctx.current_schema(),
273            )
274            .await
275            .context(AuthSnafu),
276        AuthScheme::Token(_) => Err(UnsupportedAuthScheme {
277            name: "Token AuthScheme".to_string(),
278        }),
279    }
280    .inspect_err(|e| {
281        METRIC_AUTH_FAILURE
282            .with_label_values(&[e.status_code().as_ref()])
283            .inc();
284    })
285}
286
287pub(crate) fn create_query_context(
288    header: Option<&RequestHeader>,
289    mut extensions: Vec<(String, String)>,
290) -> Result<QueryContextRef> {
291    let (catalog, schema) = header
292        .map(|header| {
293            // We provide dbname field in newer versions of protos/sdks
294            // parse dbname from header in priority
295            if !header.dbname.is_empty() {
296                parse_catalog_and_schema_from_db_string(&header.dbname)
297            } else {
298                (
299                    if !header.catalog.is_empty() {
300                        header.catalog.to_lowercase()
301                    } else {
302                        DEFAULT_CATALOG_NAME.to_string()
303                    },
304                    if !header.schema.is_empty() {
305                        header.schema.to_lowercase()
306                    } else {
307                        DEFAULT_SCHEMA_NAME.to_string()
308                    },
309                )
310            }
311        })
312        .unwrap_or_else(|| {
313            (
314                DEFAULT_CATALOG_NAME.to_string(),
315                DEFAULT_SCHEMA_NAME.to_string(),
316            )
317        });
318    let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
319    let mut ctx_builder = QueryContextBuilder::default()
320        .current_catalog(catalog)
321        .current_schema(schema)
322        .timezone(timezone);
323
324    if let Some(x) = extensions
325        .iter()
326        .position(|(k, _)| k == READ_PREFERENCE_HINT)
327    {
328        let (k, v) = extensions.swap_remove(x);
329        let Ok(read_preference) = ReadPreference::from_str(&v) else {
330            return UnknownHintSnafu {
331                hint: format!("{k}={v}"),
332            }
333            .fail();
334        };
335        ctx_builder = ctx_builder.read_preference(read_preference);
336    }
337
338    for (key, value) in extensions {
339        ctx_builder = ctx_builder.set_extension(key, value);
340    }
341    Ok(ctx_builder.build().into())
342}
343
344/// Histogram timer for handling gRPC request.
345///
346/// The timer records the elapsed time with [StatusCode::Success] on drop.
347pub(crate) struct RequestTimer {
348    start: Instant,
349    db: String,
350    request_type: String,
351    status_code: StatusCode,
352}
353
354impl RequestTimer {
355    /// Returns a new timer.
356    pub fn new(db: String, request_type: String) -> RequestTimer {
357        RequestTimer {
358            start: Instant::now(),
359            db,
360            request_type,
361            status_code: StatusCode::Success,
362        }
363    }
364
365    /// Consumes the timer and record the elapsed time with specific `status_code`.
366    pub fn record(mut self, status_code: StatusCode) {
367        self.status_code = status_code;
368    }
369}
370
371impl Drop for RequestTimer {
372    fn drop(&mut self) {
373        METRIC_SERVER_GRPC_DB_REQUEST_TIMER
374            .with_label_values(&[
375                self.db.as_str(),
376                self.request_type.as_str(),
377                self.status_code.as_ref(),
378            ])
379            .observe(self.start.elapsed().as_secs_f64());
380    }
381}
382
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_time::Timezone;

    use super::*;

    #[test]
    fn test_create_query_context() {
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };
        let hints = vec![
            ("auto_create_table".to_string(), "true".to_string()),
            ("read_preference".to_string(), "leader".to_string()),
        ];

        let ctx = create_query_context(Some(&header), hints).unwrap();

        // The catalog comes from the header; the unset schema falls back to the default.
        assert_eq!(ctx.current_catalog(), "cat-a-log");
        assert_eq!(ctx.current_schema(), DEFAULT_SCHEMA_NAME);

        let one_hour_east = FixedOffset::east_opt(3600).unwrap();
        assert_eq!(ctx.timezone(), Timezone::Offset(one_hour_east));

        // The read-preference hint is consumed by the context...
        assert!(matches!(ctx.read_preference(), ReadPreference::Leader));

        // ...so only the remaining hint is kept as an extension.
        let extensions: Vec<_> = ctx.extensions().into_iter().collect();
        assert_eq!(
            extensions,
            vec![("auto_create_table".to_string(), "true".to_string())]
        );
    }
}