1use std::str::FromStr;
18use std::time::Instant;
19
20use api::helper::request_type;
21use api::v1::auth_header::AuthScheme;
22use api::v1::{AuthHeader, Basic, GreptimeRequest, RequestHeader};
23use auth::{Identity, Password, UserInfoRef, UserProviderRef};
24use base64::prelude::BASE64_STANDARD;
25use base64::Engine;
26use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
27use common_catalog::parse_catalog_and_schema_from_db_string;
28use common_error::ext::ErrorExt;
29use common_error::status_code::StatusCode;
30use common_grpc::flight::do_put::DoPutResponse;
31use common_grpc::flight::FlightDecoder;
32use common_query::Output;
33use common_runtime::runtime::RuntimeTrait;
34use common_runtime::Runtime;
35use common_session::ReadPreference;
36use common_telemetry::tracing_context::{FutureExt, TracingContext};
37use common_telemetry::{debug, error, tracing, warn};
38use common_time::timezone::parse_timezone;
39use futures_util::StreamExt;
40use session::context::{Channel, QueryContext, QueryContextBuilder, QueryContextRef};
41use session::hints::READ_PREFERENCE_HINT;
42use snafu::{OptionExt, ResultExt};
43use table::TableRef;
44use tokio::sync::mpsc;
45use tokio::sync::mpsc::error::TrySendError;
46
47use crate::error::Error::UnsupportedAuthScheme;
48use crate::error::{
49 AuthSnafu, InvalidAuthHeaderInvalidUtf8ValueSnafu, InvalidBase64ValueSnafu, InvalidQuerySnafu,
50 JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result, UnknownHintSnafu,
51};
52use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream};
53use crate::grpc::{FlightCompression, TonicResult};
54use crate::metrics;
55use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
56use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
57
58#[derive(Clone)]
59pub struct GreptimeRequestHandler {
60 handler: ServerGrpcQueryHandlerRef,
61 user_provider: Option<UserProviderRef>,
62 runtime: Option<Runtime>,
63 pub(crate) flight_compression: FlightCompression,
64}
65
66impl GreptimeRequestHandler {
67 pub fn new(
68 handler: ServerGrpcQueryHandlerRef,
69 user_provider: Option<UserProviderRef>,
70 runtime: Option<Runtime>,
71 flight_compression: FlightCompression,
72 ) -> Self {
73 Self {
74 handler,
75 user_provider,
76 runtime,
77 flight_compression,
78 }
79 }
80
81 #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
82 pub(crate) async fn handle_request(
83 &self,
84 request: GreptimeRequest,
85 hints: Vec<(String, String)>,
86 ) -> Result<Output> {
87 let query = request.request.context(InvalidQuerySnafu {
88 reason: "Expecting non-empty GreptimeRequest.",
89 })?;
90
91 let header = request.header.as_ref();
92 let query_ctx = create_query_context(Channel::Grpc, header, hints)?;
93 let user_info = auth(self.user_provider.clone(), header, &query_ctx).await?;
94 query_ctx.set_current_user(user_info);
95
96 let handler = self.handler.clone();
97 let request_type = request_type(&query).to_string();
98 let db = query_ctx.get_db_string();
99 let timer = RequestTimer::new(db.clone(), request_type);
100 let tracing_context = TracingContext::from_current_span();
101
102 let result_future = async move {
103 handler
104 .do_query(query, query_ctx)
105 .trace(tracing_context.attach(tracing::info_span!(
106 "GreptimeRequestHandler::handle_request_runtime"
107 )))
108 .await
109 .map_err(|e| {
110 if e.status_code().should_log_error() {
111 let root_error = e.root_cause().unwrap_or(&e);
112 error!(e; "Failed to handle request, error: {}", root_error.to_string());
113 } else {
114 debug!("Failed to handle request, err: {:?}", e);
116 }
117 e
118 })
119 };
120
121 match &self.runtime {
122 Some(runtime) => {
123 runtime
131 .spawn(result_future)
132 .await
133 .context(JoinTaskSnafu)
134 .inspect_err(|e| {
135 timer.record(e.status_code());
136 })?
137 }
138 None => result_future.await,
139 }
140 }
141
142 pub(crate) async fn put_record_batches(
143 &self,
144 mut stream: PutRecordBatchRequestStream,
145 result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
146 ) {
147 let handler = self.handler.clone();
148 let runtime = self
149 .runtime
150 .clone()
151 .unwrap_or_else(common_runtime::global_runtime);
152 runtime.spawn(async move {
153 let mut table_ref: Option<TableRef> = None;
155
156 let mut decoder = FlightDecoder::default();
157 while let Some(request) = stream.next().await {
158 let request = match request {
159 Ok(request) => request,
160 Err(e) => {
161 let _ = result_sender.try_send(Err(e));
162 break;
163 }
164 };
165 let PutRecordBatchRequest {
166 table_name,
167 request_id,
168 data,
169 } = request;
170
171 let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer();
172 let result = handler
173 .put_record_batch(&table_name, &mut table_ref, &mut decoder, data)
174 .await
175 .inspect_err(|e| error!(e; "Failed to handle flight record batches"));
176 timer.observe_duration();
177 let result = result
178 .map(|x| DoPutResponse::new(request_id, x))
179 .map_err(Into::into);
180 if let Err(e)= result_sender.try_send(result)
181 && let TrySendError::Closed(_) = e {
182 warn!(r#""DoPut" client with request_id {} maybe unreachable, abort handling its message"#, request_id);
183 break;
184 }
185 }
186 });
187 }
188
189 pub(crate) async fn validate_auth(
190 &self,
191 username_and_password: Option<&str>,
192 db: Option<&str>,
193 ) -> Result<bool> {
194 if self.user_provider.is_none() {
195 return Ok(true);
196 }
197
198 let username_and_password = username_and_password.context(NotFoundAuthHeaderSnafu)?;
199 let username_and_password = BASE64_STANDARD
200 .decode(username_and_password)
201 .context(InvalidBase64ValueSnafu)
202 .and_then(|x| String::from_utf8(x).context(InvalidAuthHeaderInvalidUtf8ValueSnafu))?;
203
204 let mut split = username_and_password.splitn(2, ':');
205 let (username, password) = match (split.next(), split.next()) {
206 (Some(username), Some(password)) => (username, password),
207 (Some(username), None) => (username, ""),
208 (None, None) => return Ok(false),
209 _ => unreachable!(), };
211
212 let (catalog, schema) = if let Some(db) = db {
213 parse_catalog_and_schema_from_db_string(db)
214 } else {
215 (
216 DEFAULT_CATALOG_NAME.to_string(),
217 DEFAULT_SCHEMA_NAME.to_string(),
218 )
219 };
220 let header = RequestHeader {
221 authorization: Some(AuthHeader {
222 auth_scheme: Some(AuthScheme::Basic(Basic {
223 username: username.to_string(),
224 password: password.to_string(),
225 })),
226 }),
227 catalog,
228 schema,
229 ..Default::default()
230 };
231
232 Ok(auth(
233 self.user_provider.clone(),
234 Some(&header),
235 &QueryContext::arc(),
236 )
237 .await
238 .is_ok())
239 }
240}
241
242pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
243 request
244 .request
245 .as_ref()
246 .map(request_type)
247 .unwrap_or_default()
248}
249
250pub(crate) async fn auth(
251 user_provider: Option<UserProviderRef>,
252 header: Option<&RequestHeader>,
253 query_ctx: &QueryContextRef,
254) -> Result<UserInfoRef> {
255 let Some(user_provider) = user_provider else {
256 return Ok(auth::userinfo_by_name(None));
257 };
258
259 let auth_scheme = header
260 .and_then(|header| {
261 header
262 .authorization
263 .as_ref()
264 .and_then(|x| x.auth_scheme.clone())
265 })
266 .context(NotFoundAuthHeaderSnafu)?;
267
268 match auth_scheme {
269 AuthScheme::Basic(Basic { username, password }) => user_provider
270 .auth(
271 Identity::UserId(&username, None),
272 Password::PlainText(password.into()),
273 query_ctx.current_catalog(),
274 &query_ctx.current_schema(),
275 )
276 .await
277 .context(AuthSnafu),
278 AuthScheme::Token(_) => Err(UnsupportedAuthScheme {
279 name: "Token AuthScheme".to_string(),
280 }),
281 }
282 .inspect_err(|e| {
283 METRIC_AUTH_FAILURE
284 .with_label_values(&[e.status_code().as_ref()])
285 .inc();
286 })
287}
288
289pub(crate) fn create_query_context(
292 channel: Channel,
293 header: Option<&RequestHeader>,
294 mut extensions: Vec<(String, String)>,
295) -> Result<QueryContextRef> {
296 let (catalog, schema) = header
297 .map(|header| {
298 if !header.dbname.is_empty() {
301 parse_catalog_and_schema_from_db_string(&header.dbname)
302 } else {
303 (
304 if !header.catalog.is_empty() {
305 header.catalog.to_lowercase()
306 } else {
307 DEFAULT_CATALOG_NAME.to_string()
308 },
309 if !header.schema.is_empty() {
310 header.schema.to_lowercase()
311 } else {
312 DEFAULT_SCHEMA_NAME.to_string()
313 },
314 )
315 }
316 })
317 .unwrap_or_else(|| {
318 (
319 DEFAULT_CATALOG_NAME.to_string(),
320 DEFAULT_SCHEMA_NAME.to_string(),
321 )
322 });
323 let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
324 let mut ctx_builder = QueryContextBuilder::default()
325 .current_catalog(catalog)
326 .current_schema(schema)
327 .timezone(timezone)
328 .channel(channel);
329
330 if let Some(x) = extensions
331 .iter()
332 .position(|(k, _)| k == READ_PREFERENCE_HINT)
333 {
334 let (k, v) = extensions.swap_remove(x);
335 let Ok(read_preference) = ReadPreference::from_str(&v) else {
336 return UnknownHintSnafu {
337 hint: format!("{k}={v}"),
338 }
339 .fail();
340 };
341 ctx_builder = ctx_builder.read_preference(read_preference);
342 }
343
344 for (key, value) in extensions {
345 ctx_builder = ctx_builder.set_extension(key, value);
346 }
347 Ok(ctx_builder.build().into())
348}
349
350pub(crate) struct RequestTimer {
354 start: Instant,
355 db: String,
356 request_type: String,
357 status_code: StatusCode,
358}
359
360impl RequestTimer {
361 pub fn new(db: String, request_type: String) -> RequestTimer {
363 RequestTimer {
364 start: Instant::now(),
365 db,
366 request_type,
367 status_code: StatusCode::Success,
368 }
369 }
370
371 pub fn record(mut self, status_code: StatusCode) {
373 self.status_code = status_code;
374 }
375}
376
377impl Drop for RequestTimer {
378 fn drop(&mut self) {
379 METRIC_SERVER_GRPC_DB_REQUEST_TIMER
380 .with_label_values(&[
381 self.db.as_str(),
382 self.request_type.as_str(),
383 self.status_code.as_ref(),
384 ])
385 .observe(self.start.elapsed().as_secs_f64());
386 }
387}
388
389#[cfg(test)]
390mod tests {
391 use chrono::FixedOffset;
392 use common_time::Timezone;
393
394 use super::*;
395
396 #[test]
397 fn test_create_query_context() {
398 let header = RequestHeader {
399 catalog: "cat-a-log".to_string(),
400 timezone: "+01:00".to_string(),
401 ..Default::default()
402 };
403 let query_context = create_query_context(
404 Channel::Unknown,
405 Some(&header),
406 vec![
407 ("auto_create_table".to_string(), "true".to_string()),
408 ("read_preference".to_string(), "leader".to_string()),
409 ],
410 )
411 .unwrap();
412 assert_eq!(query_context.current_catalog(), "cat-a-log");
413 assert_eq!(query_context.current_schema(), DEFAULT_SCHEMA_NAME);
414 assert_eq!(
415 query_context.timezone(),
416 Timezone::Offset(FixedOffset::east_opt(3600).unwrap())
417 );
418 assert!(matches!(
419 query_context.read_preference(),
420 ReadPreference::Leader
421 ));
422 assert_eq!(
423 query_context.extensions().into_iter().collect::<Vec<_>>(),
424 vec![("auto_create_table".to_string(), "true".to_string())]
425 );
426 }
427}