Skip to main content

servers/grpc/
greptime_handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Handler for Greptime Database service. It's implemented by frontend.
16
17use std::collections::HashMap;
18use std::str::FromStr;
19use std::sync::{Arc, RwLock};
20use std::time::Instant;
21
22use api::helper::request_type;
23use api::v1::{GreptimeRequest, RequestHeader};
24use auth::UserProviderRef;
25use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
26use common_catalog::parse_catalog_and_schema_from_db_string;
27use common_error::ext::ErrorExt;
28use common_error::status_code::StatusCode;
29use common_grpc::flight::do_put::DoPutResponse;
30use common_query::Output;
31use common_runtime::Runtime;
32use common_runtime::runtime::RuntimeTrait;
33use common_session::ReadPreference;
34use common_telemetry::tracing_context::{FutureExt, TracingContext};
35use common_telemetry::{debug, error, tracing, warn};
36use common_time::timezone::parse_timezone;
37use futures_util::StreamExt;
38use session::context::{Channel, QueryContextBuilder, QueryContextRef};
39use session::hints::{READ_PREFERENCE_HINT, is_reserved_extension_key};
40use snafu::{OptionExt, ResultExt};
41use tokio::sync::mpsc;
42use tokio::sync::mpsc::error::TrySendError;
43use tonic::Status;
44
45use crate::error::{InvalidQuerySnafu, JoinTaskSnafu, Result, UnknownHintSnafu};
46use crate::grpc::flight::PutRecordBatchRequestStream;
47use crate::grpc::{FlightCompression, TonicResult, context_auth};
48use crate::metrics::{self, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
49use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
50
51#[derive(Clone)]
52pub struct GreptimeRequestHandler {
53    handler: ServerGrpcQueryHandlerRef,
54    pub(crate) user_provider: Option<UserProviderRef>,
55    runtime: Option<Runtime>,
56    pub(crate) flight_compression: FlightCompression,
57}
58
59impl GreptimeRequestHandler {
60    pub fn new(
61        handler: ServerGrpcQueryHandlerRef,
62        user_provider: Option<UserProviderRef>,
63        runtime: Option<Runtime>,
64        flight_compression: FlightCompression,
65    ) -> Self {
66        Self {
67            handler,
68            user_provider,
69            runtime,
70            flight_compression,
71        }
72    }
73
74    #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
75    pub(crate) async fn handle_request(
76        &self,
77        request: GreptimeRequest,
78        hints: Vec<(String, String)>,
79    ) -> Result<Output> {
80        let header = request.header.as_ref();
81        let query_ctx = create_query_context(Channel::Grpc, header, hints, HashMap::new())?;
82        self.handle_request_with_query_ctx(request, query_ctx).await
83    }
84
85    pub(crate) async fn handle_request_with_query_ctx(
86        &self,
87        request: GreptimeRequest,
88        query_ctx: QueryContextRef,
89    ) -> Result<Output> {
90        let query = request.request.context(InvalidQuerySnafu {
91            reason: "Expecting non-empty GreptimeRequest.",
92        })?;
93
94        let header = request.header.as_ref();
95        let user_info = context_auth::auth(self.user_provider.clone(), header, &query_ctx).await?;
96        query_ctx.set_current_user(user_info);
97
98        let handler = self.handler.clone();
99        let request_type = request_type(&query).to_string();
100        let db = query_ctx.get_db_string();
101        let timer = RequestTimer::new(db.clone(), request_type);
102        let tracing_context = TracingContext::from_current_span();
103
104        let result_future = async move {
105            handler
106                .do_query(query, query_ctx)
107                .trace(tracing_context.attach(tracing::info_span!(
108                    "GreptimeRequestHandler::handle_request_runtime"
109                )))
110                .await
111                .map_err(|e| {
112                    if e.status_code().should_log_error() {
113                        let root_error = e.root_cause().unwrap_or(&e);
114                        error!(e; "Failed to handle request, error: {}", root_error.to_string());
115                    } else {
116                        // Currently, we still print a debug log.
117                        debug!("Failed to handle request, err: {:?}", e);
118                    }
119                    e
120                })
121        };
122
123        match &self.runtime {
124            Some(runtime) => {
125                // Executes requests in another runtime to
126                // 1. prevent the execution from being cancelled unexpected by Tonic runtime;
127                //   - Refer to our blog for the rational behind it:
128                //     https://www.greptime.com/blogs/2023-01-12-hidden-control-flow.html
129                //   - Obtaining a `JoinHandle` to get the panic message (if there's any).
130                //     From its docs, `JoinHandle` is cancel safe. The task keeps running even it's handle been dropped.
131                // 2. avoid the handler blocks the gRPC runtime incidentally.
132                runtime
133                    .spawn(result_future)
134                    .await
135                    .context(JoinTaskSnafu)
136                    .inspect_err(|e| {
137                        timer.record(e.status_code());
138                    })?
139            }
140            None => result_future.await,
141        }
142    }
143
144    pub(crate) async fn put_record_batches(
145        &self,
146        stream: PutRecordBatchRequestStream,
147        result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
148        query_ctx: QueryContextRef,
149    ) {
150        let handler = self.handler.clone();
151        let runtime = self
152            .runtime
153            .clone()
154            .unwrap_or_else(common_runtime::global_runtime);
155        runtime.spawn(async move {
156            let mut result_stream = handler.handle_put_record_batch_stream(stream, query_ctx);
157
158            while let Some(result) = result_stream.next().await {
159                match &result {
160                    Ok(response) => {
161                        // Record the elapsed time metric from the response
162                        metrics::GRPC_BULK_INSERT_ELAPSED.observe(response.elapsed_secs());
163                    }
164                    Err(e) => {
165                        error!(e; "Failed to handle flight record batches");
166                    }
167                }
168
169                if let Err(e) =
170                    result_sender.try_send(result.map_err(|e| Status::from_error(Box::new(e))))
171                    && let TrySendError::Closed(_) = e
172                {
173                    warn!(r#""DoPut" client maybe unreachable, abort handling its message"#);
174                    break;
175                }
176            }
177        });
178    }
179}
180
181pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
182    request
183        .request
184        .as_ref()
185        .map(request_type)
186        .unwrap_or_default()
187}
188
189/// Creates a new `QueryContext` from the provided request header and extensions.
190/// Strongly recommend setting an appropriate channel, as this is very helpful for statistics.
191pub(crate) fn create_query_context(
192    channel: Channel,
193    header: Option<&RequestHeader>,
194    mut extensions: Vec<(String, String)>,
195    snapshot_seqs: HashMap<u64, u64>,
196) -> Result<QueryContextRef> {
197    let (catalog, schema) = header
198        .map(|header| {
199            // We provide dbname field in newer versions of protos/sdks
200            // parse dbname from header in priority
201            if !header.dbname.is_empty() {
202                parse_catalog_and_schema_from_db_string(&header.dbname)
203            } else {
204                (
205                    if !header.catalog.is_empty() {
206                        header.catalog.to_lowercase()
207                    } else {
208                        DEFAULT_CATALOG_NAME.to_string()
209                    },
210                    if !header.schema.is_empty() {
211                        header.schema.to_lowercase()
212                    } else {
213                        DEFAULT_SCHEMA_NAME.to_string()
214                    },
215                )
216            }
217        })
218        .unwrap_or_else(|| {
219            (
220                DEFAULT_CATALOG_NAME.to_string(),
221                DEFAULT_SCHEMA_NAME.to_string(),
222            )
223        });
224    let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
225    let mut ctx_builder = QueryContextBuilder::default()
226        .current_catalog(catalog)
227        .current_schema(schema)
228        .timezone(timezone)
229        .channel(channel)
230        .snapshot_seqs(Arc::new(RwLock::new(snapshot_seqs)));
231
232    if let Some(x) = extensions
233        .iter()
234        .position(|(k, _)| k == READ_PREFERENCE_HINT)
235    {
236        let (k, v) = extensions.swap_remove(x);
237        let Ok(read_preference) = ReadPreference::from_str(&v) else {
238            return UnknownHintSnafu {
239                hint: format!("{k}={v}"),
240            }
241            .fail();
242        };
243        ctx_builder = ctx_builder.read_preference(read_preference);
244    }
245
246    for (key, value) in extensions {
247        if is_reserved_extension_key(&key) {
248            debug!(
249                key = key.as_str(),
250                "Ignoring reserved external query context extension key"
251            );
252            continue;
253        }
254        ctx_builder = ctx_builder.set_extension(key, value);
255    }
256    Ok(ctx_builder.build().into())
257}
258
259/// Histogram timer for handling gRPC request.
260///
261/// The timer records the elapsed time with [StatusCode::Success] on drop.
262pub(crate) struct RequestTimer {
263    start: Instant,
264    db: String,
265    request_type: String,
266    status_code: StatusCode,
267}
268
269impl RequestTimer {
270    /// Returns a new timer.
271    pub fn new(db: String, request_type: String) -> RequestTimer {
272        RequestTimer {
273            start: Instant::now(),
274            db,
275            request_type,
276            status_code: StatusCode::Success,
277        }
278    }
279
280    /// Consumes the timer and record the elapsed time with specific `status_code`.
281    pub fn record(mut self, status_code: StatusCode) {
282        self.status_code = status_code;
283    }
284}
285
286impl Drop for RequestTimer {
287    fn drop(&mut self) {
288        METRIC_SERVER_GRPC_DB_REQUEST_TIMER
289            .with_label_values(&[
290                self.db.as_str(),
291                self.request_type.as_str(),
292                self.status_code.as_ref(),
293            ])
294            .observe(self.start.elapsed().as_secs_f64());
295    }
296}
297
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_time::Timezone;
    use session::hints::{
        INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY, REMOTE_QUERY_ID_EXTENSION_KEY,
    };

    use super::*;

    /// Header fields, hints and snapshot sequences are reflected in the created context,
    /// while values supplied for reserved extension keys are not taken over.
    #[test]
    fn test_create_query_context() {
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };
        let hints = vec![
            ("auto_create_table".to_string(), "true".to_string()),
            ("read_preference".to_string(), "leader".to_string()),
            (
                REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
                "spoofed".to_string(),
            ),
            (
                INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY.to_string(),
                "spoofed-regs".to_string(),
            ),
        ];
        let snapshot_seqs = HashMap::from([(7, 88)]);

        let ctx =
            create_query_context(Channel::Unknown, Some(&header), hints, snapshot_seqs).unwrap();

        // Values derived from the header and the snapshot sequences.
        assert_eq!(ctx.get_snapshot(7), Some(88));
        assert_eq!(ctx.current_catalog(), "cat-a-log");
        assert_eq!(ctx.current_schema(), DEFAULT_SCHEMA_NAME);
        assert_eq!(
            ctx.timezone(),
            Timezone::Offset(FixedOffset::east_opt(3600).unwrap())
        );

        // The read preference hint is applied to the context.
        assert!(matches!(ctx.read_preference(), ReadPreference::Leader));

        // Sort the extensions by key to get a stable order for the checks below.
        let mut sorted_extensions = ctx.extensions().into_iter().collect::<Vec<_>>();
        sorted_extensions.sort_unstable_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
        assert_eq!(
            sorted_extensions[0],
            ("auto_create_table".to_string(), "true".to_string())
        );
        assert_eq!(
            sorted_extensions[1].0,
            REMOTE_QUERY_ID_EXTENSION_KEY.to_string()
        );
        assert_eq!(
            ctx.remote_query_id(),
            Some(sorted_extensions[1].1.as_str())
        );

        // The spoofed values for reserved keys must not show up.
        assert_ne!(ctx.remote_query_id(), Some("spoofed"));
        assert!(
            ctx.extension(INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY)
                .is_none()
        );
    }

    /// A remote query id passed in as an extension does not become the context's remote
    /// query id.
    #[test]
    fn test_create_query_context_ignores_remote_query_id_extension() {
        let spoofed_hints = vec![(
            REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
            "spoofed-query-id".to_string(),
        )];

        let ctx =
            create_query_context(Channel::Grpc, None, spoofed_hints, HashMap::new()).unwrap();

        assert_ne!(ctx.remote_query_id(), Some("spoofed-query-id"));
        assert_eq!(
            ctx.extension(REMOTE_QUERY_ID_EXTENSION_KEY),
            ctx.remote_query_id()
        );
    }
}