Skip to main content

servers/grpc/
greptime_handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Handler for the Greptime Database gRPC service. It's implemented by the frontend.
16
17use std::collections::HashMap;
18use std::str::FromStr;
19use std::sync::{Arc, RwLock};
20use std::time::Instant;
21
22use api::helper::request_type;
23use api::v1::{GreptimeRequest, RequestHeader};
24use auth::UserProviderRef;
25use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
26use common_catalog::parse_catalog_and_schema_from_db_string;
27use common_error::ext::ErrorExt;
28use common_error::status_code::StatusCode;
29use common_grpc::flight::do_put::DoPutResponse;
30use common_query::Output;
31use common_runtime::Runtime;
32use common_runtime::runtime::RuntimeTrait;
33use common_session::ReadPreference;
34use common_telemetry::tracing_context::{FutureExt, TracingContext};
35use common_telemetry::{debug, error, tracing, warn};
36use common_time::timezone::parse_timezone;
37use futures_util::StreamExt;
38use session::context::{Channel, QueryContextBuilder, QueryContextRef};
39use session::hints::{READ_PREFERENCE_HINT, is_reserved_extension_key};
40use snafu::{OptionExt, ResultExt};
41use tokio::sync::mpsc;
42use tokio::sync::mpsc::error::TrySendError;
43use tonic::Status;
44
45use crate::error::{InvalidQuerySnafu, JoinTaskSnafu, Result, UnknownHintSnafu};
46use crate::grpc::flight::PutRecordBatchRequestStream;
47use crate::grpc::{FlightCompression, TonicResult, context_auth};
48use crate::metrics::{self, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
49use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
50
51#[derive(Clone)]
52pub struct GreptimeRequestHandler {
53    handler: ServerGrpcQueryHandlerRef,
54    pub(crate) user_provider: Option<UserProviderRef>,
55    runtime: Option<Runtime>,
56    pub(crate) flight_compression: FlightCompression,
57}
58
59impl GreptimeRequestHandler {
60    pub fn new(
61        handler: ServerGrpcQueryHandlerRef,
62        user_provider: Option<UserProviderRef>,
63        runtime: Option<Runtime>,
64        flight_compression: FlightCompression,
65    ) -> Self {
66        Self {
67            handler,
68            user_provider,
69            runtime,
70            flight_compression,
71        }
72    }
73
74    #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
75    pub(crate) async fn handle_request(
76        &self,
77        request: GreptimeRequest,
78        hints: Vec<(String, String)>,
79    ) -> Result<Output> {
80        let header = request.header.as_ref();
81        let query_ctx = create_query_context(Channel::Grpc, header, hints, HashMap::new())?;
82        self.handle_request_with_query_ctx(request, query_ctx).await
83    }
84
85    pub(crate) async fn handle_request_with_query_ctx(
86        &self,
87        request: GreptimeRequest,
88        query_ctx: QueryContextRef,
89    ) -> Result<Output> {
90        let query = request.request.context(InvalidQuerySnafu {
91            reason: "Expecting non-empty GreptimeRequest.",
92        })?;
93
94        let header = request.header.as_ref();
95        let user_info = context_auth::auth(self.user_provider.clone(), header, &query_ctx).await?;
96        query_ctx.set_current_user(user_info);
97
98        let handler = self.handler.clone();
99        let request_type = request_type(&query).to_string();
100        let db = query_ctx.get_db_string();
101        let timer = RequestTimer::new(db.clone(), request_type);
102        let tracing_context = TracingContext::from_current_span();
103
104        let result_future = async move {
105            handler
106                .do_query(query, query_ctx)
107                .trace(tracing_context.attach(tracing::info_span!(
108                    "GreptimeRequestHandler::handle_request_runtime"
109                )))
110                .await
111                .map_err(|e| {
112                    if e.status_code().should_log_error() {
113                        let root_error = e.root_cause().unwrap_or(&e);
114                        error!(e; "Failed to handle request, error: {}", root_error.to_string());
115                    } else {
116                        // Currently, we still print a debug log.
117                        debug!("Failed to handle request, err: {:?}", e);
118                    }
119                    e
120                })
121        };
122
123        match &self.runtime {
124            Some(runtime) => {
125                // Executes requests in another runtime to
126                // 1. prevent the execution from being cancelled unexpected by Tonic runtime;
127                //   - Refer to our blog for the rational behind it:
128                //     https://www.greptime.com/blogs/2023-01-12-hidden-control-flow.html
129                //   - Obtaining a `JoinHandle` to get the panic message (if there's any).
130                //     From its docs, `JoinHandle` is cancel safe. The task keeps running even it's handle been dropped.
131                // 2. avoid the handler blocks the gRPC runtime incidentally.
132                runtime
133                    .spawn(result_future)
134                    .await
135                    .context(JoinTaskSnafu)
136                    .inspect_err(|e| {
137                        timer.record(e.status_code());
138                    })?
139            }
140            None => result_future.await,
141        }
142    }
143
144    pub(crate) async fn put_record_batches(
145        &self,
146        stream: PutRecordBatchRequestStream,
147        result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
148        query_ctx: QueryContextRef,
149    ) {
150        let handler = self.handler.clone();
151        let runtime = self
152            .runtime
153            .clone()
154            .unwrap_or_else(common_runtime::global_runtime);
155        runtime.spawn(async move {
156            let mut result_stream = handler.handle_put_record_batch_stream(stream, query_ctx);
157
158            while let Some(result) = result_stream.next().await {
159                match &result {
160                    Ok(response) => {
161                        // Record the elapsed time metric from the response
162                        metrics::GRPC_BULK_INSERT_ELAPSED.observe(response.elapsed_secs());
163                    }
164                    Err(e) => {
165                        error!(e; "Failed to handle flight record batches");
166                    }
167                }
168
169                if let Err(e) = result_sender.try_send(result.map_err(Status::from))
170                    && let TrySendError::Closed(_) = e
171                {
172                    warn!(r#""DoPut" client maybe unreachable, abort handling its message"#);
173                    break;
174                }
175            }
176        });
177    }
178}
179
180pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
181    request
182        .request
183        .as_ref()
184        .map(request_type)
185        .unwrap_or_default()
186}
187
188/// Creates a new `QueryContext` from the provided request header and extensions.
189/// Strongly recommend setting an appropriate channel, as this is very helpful for statistics.
190pub(crate) fn create_query_context(
191    channel: Channel,
192    header: Option<&RequestHeader>,
193    mut extensions: Vec<(String, String)>,
194    snapshot_seqs: HashMap<u64, u64>,
195) -> Result<QueryContextRef> {
196    let (catalog, schema) = header
197        .map(|header| {
198            // We provide dbname field in newer versions of protos/sdks
199            // parse dbname from header in priority
200            if !header.dbname.is_empty() {
201                parse_catalog_and_schema_from_db_string(&header.dbname)
202            } else {
203                (
204                    if !header.catalog.is_empty() {
205                        header.catalog.to_lowercase()
206                    } else {
207                        DEFAULT_CATALOG_NAME.to_string()
208                    },
209                    if !header.schema.is_empty() {
210                        header.schema.to_lowercase()
211                    } else {
212                        DEFAULT_SCHEMA_NAME.to_string()
213                    },
214                )
215            }
216        })
217        .unwrap_or_else(|| {
218            (
219                DEFAULT_CATALOG_NAME.to_string(),
220                DEFAULT_SCHEMA_NAME.to_string(),
221            )
222        });
223    let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
224    let mut ctx_builder = QueryContextBuilder::default()
225        .current_catalog(catalog)
226        .current_schema(schema)
227        .timezone(timezone)
228        .channel(channel)
229        .snapshot_seqs(Arc::new(RwLock::new(snapshot_seqs)));
230
231    if let Some(x) = extensions
232        .iter()
233        .position(|(k, _)| k == READ_PREFERENCE_HINT)
234    {
235        let (k, v) = extensions.swap_remove(x);
236        let Ok(read_preference) = ReadPreference::from_str(&v) else {
237            return UnknownHintSnafu {
238                hint: format!("{k}={v}"),
239            }
240            .fail();
241        };
242        ctx_builder = ctx_builder.read_preference(read_preference);
243    }
244
245    for (key, value) in extensions {
246        if is_reserved_extension_key(&key) {
247            debug!(
248                key = key.as_str(),
249                "Ignoring reserved external query context extension key"
250            );
251            continue;
252        }
253        ctx_builder = ctx_builder.set_extension(key, value);
254    }
255    Ok(ctx_builder.build().into())
256}
257
258/// Histogram timer for handling gRPC request.
259///
260/// The timer records the elapsed time with [StatusCode::Success] on drop.
261pub(crate) struct RequestTimer {
262    start: Instant,
263    db: String,
264    request_type: String,
265    status_code: StatusCode,
266}
267
268impl RequestTimer {
269    /// Returns a new timer.
270    pub fn new(db: String, request_type: String) -> RequestTimer {
271        RequestTimer {
272            start: Instant::now(),
273            db,
274            request_type,
275            status_code: StatusCode::Success,
276        }
277    }
278
279    /// Consumes the timer and record the elapsed time with specific `status_code`.
280    pub fn record(mut self, status_code: StatusCode) {
281        self.status_code = status_code;
282    }
283}
284
285impl Drop for RequestTimer {
286    fn drop(&mut self) {
287        METRIC_SERVER_GRPC_DB_REQUEST_TIMER
288            .with_label_values(&[
289                self.db.as_str(),
290                self.request_type.as_str(),
291                self.status_code.as_ref(),
292            ])
293            .observe(self.start.elapsed().as_secs_f64());
294    }
295}
296
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_error::ext::BoxedError;
    use common_error::{GREPTIME_DB_HEADER_ERROR_CODE, GREPTIME_DB_HEADER_ERROR_RETRY_HINT};
    use common_time::Timezone;
    use query::options::FLOW_SCHEDULED_TIME_MILLIS;
    use session::hints::{
        INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY, REMOTE_QUERY_ID_EXTENSION_KEY,
    };
    use snafu::ResultExt;
    use tonic::Code;

    use super::*;
    use crate::error::{ExecuteGrpcRequestSnafu, InvalidParameterSnafu};

    /// Builds one owned extension `(key, value)` pair.
    fn extension_pair(key: impl ToString, value: impl ToString) -> (String, String) {
        (key.to_string(), value.to_string())
    }

    /// Header fields, snapshot sequences and the read-preference hint are applied. Client
    /// supplied values for reserved keys are not taken at face value.
    #[test]
    fn test_create_query_context() {
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };
        let extensions = vec![
            extension_pair("auto_create_table", "true"),
            extension_pair("read_preference", "leader"),
            extension_pair(REMOTE_QUERY_ID_EXTENSION_KEY, "spoofed"),
            extension_pair(
                INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY,
                "spoofed-regs",
            ),
            extension_pair(FLOW_SCHEDULED_TIME_MILLIS, "1700000000000"),
        ];
        let ctx = create_query_context(
            Channel::Unknown,
            Some(&header),
            extensions,
            HashMap::from([(7, 88)]),
        )
        .unwrap();

        // Header-derived settings and snapshot sequences.
        assert_eq!(ctx.get_snapshot(7), Some(88));
        assert_eq!(ctx.current_catalog(), "cat-a-log");
        assert_eq!(ctx.current_schema(), DEFAULT_SCHEMA_NAME);
        let one_hour_east = FixedOffset::east_opt(3600).unwrap();
        assert_eq!(ctx.timezone(), Timezone::Offset(one_hour_east));
        assert!(matches!(ctx.read_preference(), ReadPreference::Leader));

        // Extensions.
        assert_eq!(ctx.extension("auto_create_table"), Some("true"));
        assert_ne!(ctx.remote_query_id(), Some("spoofed"));
        assert!(
            ctx.extension(INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY)
                .is_none()
        );
        assert_eq!(
            ctx.extension(FLOW_SCHEDULED_TIME_MILLIS),
            Some("1700000000000")
        );
    }

    /// A client-supplied remote query id extension must not become the remote query id.
    #[test]
    fn test_create_query_context_ignores_remote_query_id_extension() {
        let ctx = create_query_context(
            Channel::Grpc,
            None,
            vec![extension_pair(
                REMOTE_QUERY_ID_EXTENSION_KEY,
                "spoofed-query-id",
            )],
            HashMap::new(),
        )
        .unwrap();

        let remote_query_id = ctx.remote_query_id();
        assert_ne!(remote_query_id, Some("spoofed-query-id"));
        assert_eq!(ctx.extension(REMOTE_QUERY_ID_EXTENSION_KEY), remote_query_id);
    }

    /// Converting a wrapped error into a `tonic::Status` keeps its code, message and metadata.
    #[test]
    fn test_record_batch_error_to_status_preserves_error_details() {
        let source = BoxedError::new(
            InvalidParameterSnafu {
                reason: "Column not found, column: new_col",
            }
            .build(),
        );
        let err = Err::<(), _>(source)
            .context(ExecuteGrpcRequestSnafu)
            .unwrap_err();
        let status = Status::from(err);

        assert_eq!(status.code(), Code::InvalidArgument);

        let message = status.message();
        for expected in [
            "Column not found, column: new_col",
            "Invalid request parameter: Column not found",
        ] {
            assert!(message.contains(expected));
        }

        let metadata = status.metadata();
        for key in [GREPTIME_DB_HEADER_ERROR_CODE, GREPTIME_DB_HEADER_ERROR_RETRY_HINT] {
            assert!(metadata.contains_key(key));
        }
    }
}