1use std::collections::HashMap;
18use std::str::FromStr;
19use std::sync::{Arc, RwLock};
20use std::time::Instant;
21
22use api::helper::request_type;
23use api::v1::{GreptimeRequest, RequestHeader};
24use auth::UserProviderRef;
25use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
26use common_catalog::parse_catalog_and_schema_from_db_string;
27use common_error::ext::ErrorExt;
28use common_error::status_code::StatusCode;
29use common_grpc::flight::do_put::DoPutResponse;
30use common_query::Output;
31use common_runtime::Runtime;
32use common_runtime::runtime::RuntimeTrait;
33use common_session::ReadPreference;
34use common_telemetry::tracing_context::{FutureExt, TracingContext};
35use common_telemetry::{debug, error, tracing, warn};
36use common_time::timezone::parse_timezone;
37use futures_util::StreamExt;
38use session::context::{Channel, QueryContextBuilder, QueryContextRef};
39use session::hints::{READ_PREFERENCE_HINT, is_reserved_extension_key};
40use snafu::{OptionExt, ResultExt};
41use tokio::sync::mpsc;
42use tokio::sync::mpsc::error::TrySendError;
43use tonic::Status;
44
45use crate::error::{InvalidQuerySnafu, JoinTaskSnafu, Result, UnknownHintSnafu};
46use crate::grpc::flight::PutRecordBatchRequestStream;
47use crate::grpc::{FlightCompression, TonicResult, context_auth};
48use crate::metrics::{self, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
49use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
50
51#[derive(Clone)]
52pub struct GreptimeRequestHandler {
53 handler: ServerGrpcQueryHandlerRef,
54 pub(crate) user_provider: Option<UserProviderRef>,
55 runtime: Option<Runtime>,
56 pub(crate) flight_compression: FlightCompression,
57}
58
59impl GreptimeRequestHandler {
60 pub fn new(
61 handler: ServerGrpcQueryHandlerRef,
62 user_provider: Option<UserProviderRef>,
63 runtime: Option<Runtime>,
64 flight_compression: FlightCompression,
65 ) -> Self {
66 Self {
67 handler,
68 user_provider,
69 runtime,
70 flight_compression,
71 }
72 }
73
74 #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
75 pub(crate) async fn handle_request(
76 &self,
77 request: GreptimeRequest,
78 hints: Vec<(String, String)>,
79 ) -> Result<Output> {
80 let header = request.header.as_ref();
81 let query_ctx = create_query_context(Channel::Grpc, header, hints, HashMap::new())?;
82 self.handle_request_with_query_ctx(request, query_ctx).await
83 }
84
85 pub(crate) async fn handle_request_with_query_ctx(
86 &self,
87 request: GreptimeRequest,
88 query_ctx: QueryContextRef,
89 ) -> Result<Output> {
90 let query = request.request.context(InvalidQuerySnafu {
91 reason: "Expecting non-empty GreptimeRequest.",
92 })?;
93
94 let header = request.header.as_ref();
95 let user_info = context_auth::auth(self.user_provider.clone(), header, &query_ctx).await?;
96 query_ctx.set_current_user(user_info);
97
98 let handler = self.handler.clone();
99 let request_type = request_type(&query).to_string();
100 let db = query_ctx.get_db_string();
101 let timer = RequestTimer::new(db.clone(), request_type);
102 let tracing_context = TracingContext::from_current_span();
103
104 let result_future = async move {
105 handler
106 .do_query(query, query_ctx)
107 .trace(tracing_context.attach(tracing::info_span!(
108 "GreptimeRequestHandler::handle_request_runtime"
109 )))
110 .await
111 .map_err(|e| {
112 if e.status_code().should_log_error() {
113 let root_error = e.root_cause().unwrap_or(&e);
114 error!(e; "Failed to handle request, error: {}", root_error.to_string());
115 } else {
116 debug!("Failed to handle request, err: {:?}", e);
118 }
119 e
120 })
121 };
122
123 match &self.runtime {
124 Some(runtime) => {
125 runtime
133 .spawn(result_future)
134 .await
135 .context(JoinTaskSnafu)
136 .inspect_err(|e| {
137 timer.record(e.status_code());
138 })?
139 }
140 None => result_future.await,
141 }
142 }
143
144 pub(crate) async fn put_record_batches(
145 &self,
146 stream: PutRecordBatchRequestStream,
147 result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
148 query_ctx: QueryContextRef,
149 ) {
150 let handler = self.handler.clone();
151 let runtime = self
152 .runtime
153 .clone()
154 .unwrap_or_else(common_runtime::global_runtime);
155 runtime.spawn(async move {
156 let mut result_stream = handler.handle_put_record_batch_stream(stream, query_ctx);
157
158 while let Some(result) = result_stream.next().await {
159 match &result {
160 Ok(response) => {
161 metrics::GRPC_BULK_INSERT_ELAPSED.observe(response.elapsed_secs());
163 }
164 Err(e) => {
165 error!(e; "Failed to handle flight record batches");
166 }
167 }
168
169 if let Err(e) =
170 result_sender.try_send(result.map_err(|e| Status::from_error(Box::new(e))))
171 && let TrySendError::Closed(_) = e
172 {
173 warn!(r#""DoPut" client maybe unreachable, abort handling its message"#);
174 break;
175 }
176 }
177 });
178 }
179}
180
181pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
182 request
183 .request
184 .as_ref()
185 .map(request_type)
186 .unwrap_or_default()
187}
188
189pub(crate) fn create_query_context(
192 channel: Channel,
193 header: Option<&RequestHeader>,
194 mut extensions: Vec<(String, String)>,
195 snapshot_seqs: HashMap<u64, u64>,
196) -> Result<QueryContextRef> {
197 let (catalog, schema) = header
198 .map(|header| {
199 if !header.dbname.is_empty() {
202 parse_catalog_and_schema_from_db_string(&header.dbname)
203 } else {
204 (
205 if !header.catalog.is_empty() {
206 header.catalog.to_lowercase()
207 } else {
208 DEFAULT_CATALOG_NAME.to_string()
209 },
210 if !header.schema.is_empty() {
211 header.schema.to_lowercase()
212 } else {
213 DEFAULT_SCHEMA_NAME.to_string()
214 },
215 )
216 }
217 })
218 .unwrap_or_else(|| {
219 (
220 DEFAULT_CATALOG_NAME.to_string(),
221 DEFAULT_SCHEMA_NAME.to_string(),
222 )
223 });
224 let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
225 let mut ctx_builder = QueryContextBuilder::default()
226 .current_catalog(catalog)
227 .current_schema(schema)
228 .timezone(timezone)
229 .channel(channel)
230 .snapshot_seqs(Arc::new(RwLock::new(snapshot_seqs)));
231
232 if let Some(x) = extensions
233 .iter()
234 .position(|(k, _)| k == READ_PREFERENCE_HINT)
235 {
236 let (k, v) = extensions.swap_remove(x);
237 let Ok(read_preference) = ReadPreference::from_str(&v) else {
238 return UnknownHintSnafu {
239 hint: format!("{k}={v}"),
240 }
241 .fail();
242 };
243 ctx_builder = ctx_builder.read_preference(read_preference);
244 }
245
246 for (key, value) in extensions {
247 if is_reserved_extension_key(&key) {
248 debug!(
249 key = key.as_str(),
250 "Ignoring reserved external query context extension key"
251 );
252 continue;
253 }
254 ctx_builder = ctx_builder.set_extension(key, value);
255 }
256 Ok(ctx_builder.build().into())
257}
258
259pub(crate) struct RequestTimer {
263 start: Instant,
264 db: String,
265 request_type: String,
266 status_code: StatusCode,
267}
268
269impl RequestTimer {
270 pub fn new(db: String, request_type: String) -> RequestTimer {
272 RequestTimer {
273 start: Instant::now(),
274 db,
275 request_type,
276 status_code: StatusCode::Success,
277 }
278 }
279
280 pub fn record(mut self, status_code: StatusCode) {
282 self.status_code = status_code;
283 }
284}
285
286impl Drop for RequestTimer {
287 fn drop(&mut self) {
288 METRIC_SERVER_GRPC_DB_REQUEST_TIMER
289 .with_label_values(&[
290 self.db.as_str(),
291 self.request_type.as_str(),
292 self.status_code.as_ref(),
293 ])
294 .observe(self.start.elapsed().as_secs_f64());
295 }
296}
297
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_time::Timezone;
    use session::hints::{
        INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY, REMOTE_QUERY_ID_EXTENSION_KEY,
    };

    use super::*;

    /// Builds an owned `(key, value)` extension pair.
    fn pair(key: &str, value: &str) -> (String, String) {
        (key.to_string(), value.to_string())
    }

    #[test]
    fn test_create_query_context() {
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };
        let hints = vec![
            pair("auto_create_table", "true"),
            pair("read_preference", "leader"),
            // Reserved keys supplied by the client must not end up in the context as given.
            pair(REMOTE_QUERY_ID_EXTENSION_KEY, "spoofed"),
            pair(
                INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY,
                "spoofed-regs",
            ),
        ];
        let ctx = create_query_context(
            Channel::Unknown,
            Some(&header),
            hints,
            HashMap::from([(7, 88)]),
        )
        .unwrap();

        assert_eq!(ctx.get_snapshot(7), Some(88));
        assert_eq!(ctx.current_catalog(), "cat-a-log");
        assert_eq!(ctx.current_schema(), DEFAULT_SCHEMA_NAME);
        let one_hour_east = Timezone::Offset(FixedOffset::east_opt(3600).unwrap());
        assert_eq!(ctx.timezone(), one_hour_east);
        assert!(matches!(ctx.read_preference(), ReadPreference::Leader));

        let mut exts = ctx.extensions().into_iter().collect::<Vec<_>>();
        exts.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        assert_eq!(exts[0], pair("auto_create_table", "true"));
        let (id_key, id_value) = &exts[1];
        assert_eq!(*id_key, REMOTE_QUERY_ID_EXTENSION_KEY.to_string());
        assert_eq!(ctx.remote_query_id(), Some(id_value.as_str()));
        assert_ne!(ctx.remote_query_id(), Some("spoofed"));
        assert!(
            ctx.extension(INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY)
                .is_none()
        );
    }

    #[test]
    fn test_create_query_context_ignores_remote_query_id_extension() {
        let spoofed = pair(REMOTE_QUERY_ID_EXTENSION_KEY, "spoofed-query-id");
        let ctx = create_query_context(Channel::Grpc, None, vec![spoofed], HashMap::new())
            .unwrap();

        assert_ne!(ctx.remote_query_id(), Some("spoofed-query-id"));
        assert_eq!(
            ctx.extension(REMOTE_QUERY_ID_EXTENSION_KEY),
            ctx.remote_query_id()
        );
    }
}