1use std::str::FromStr;
18use std::time::Instant;
19
20use api::helper::request_type;
21use api::v1::auth_header::AuthScheme;
22use api::v1::{AuthHeader, Basic, GreptimeRequest, RequestHeader};
23use auth::{Identity, Password, UserInfoRef, UserProviderRef};
24use base64::prelude::BASE64_STANDARD;
25use base64::Engine;
26use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
27use common_catalog::parse_catalog_and_schema_from_db_string;
28use common_error::ext::ErrorExt;
29use common_error::status_code::StatusCode;
30use common_grpc::flight::do_put::DoPutResponse;
31use common_grpc::flight::FlightDecoder;
32use common_query::Output;
33use common_runtime::runtime::RuntimeTrait;
34use common_runtime::Runtime;
35use common_session::ReadPreference;
36use common_telemetry::tracing_context::{FutureExt, TracingContext};
37use common_telemetry::{debug, error, tracing, warn};
38use common_time::timezone::parse_timezone;
39use futures_util::StreamExt;
40use session::context::{QueryContext, QueryContextBuilder, QueryContextRef};
41use session::hints::READ_PREFERENCE_HINT;
42use snafu::{OptionExt, ResultExt};
43use table::TableRef;
44use tokio::sync::mpsc;
45
46use crate::error::Error::UnsupportedAuthScheme;
47use crate::error::{
48 AuthSnafu, InvalidAuthHeaderInvalidUtf8ValueSnafu, InvalidBase64ValueSnafu, InvalidQuerySnafu,
49 JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result, UnknownHintSnafu,
50};
51use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream};
52use crate::grpc::{FlightCompression, TonicResult};
53use crate::metrics;
54use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
55use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
56
57#[derive(Clone)]
58pub struct GreptimeRequestHandler {
59 handler: ServerGrpcQueryHandlerRef,
60 user_provider: Option<UserProviderRef>,
61 runtime: Option<Runtime>,
62 pub(crate) flight_compression: FlightCompression,
63}
64
65impl GreptimeRequestHandler {
66 pub fn new(
67 handler: ServerGrpcQueryHandlerRef,
68 user_provider: Option<UserProviderRef>,
69 runtime: Option<Runtime>,
70 flight_compression: FlightCompression,
71 ) -> Self {
72 Self {
73 handler,
74 user_provider,
75 runtime,
76 flight_compression,
77 }
78 }
79
80 #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
81 pub(crate) async fn handle_request(
82 &self,
83 request: GreptimeRequest,
84 hints: Vec<(String, String)>,
85 ) -> Result<Output> {
86 let query = request.request.context(InvalidQuerySnafu {
87 reason: "Expecting non-empty GreptimeRequest.",
88 })?;
89
90 let header = request.header.as_ref();
91 let query_ctx = create_query_context(header, hints)?;
92 let user_info = auth(self.user_provider.clone(), header, &query_ctx).await?;
93 query_ctx.set_current_user(user_info);
94
95 let handler = self.handler.clone();
96 let request_type = request_type(&query).to_string();
97 let db = query_ctx.get_db_string();
98 let timer = RequestTimer::new(db.clone(), request_type);
99 let tracing_context = TracingContext::from_current_span();
100
101 let result_future = async move {
102 handler
103 .do_query(query, query_ctx)
104 .trace(tracing_context.attach(tracing::info_span!(
105 "GreptimeRequestHandler::handle_request_runtime"
106 )))
107 .await
108 .map_err(|e| {
109 if e.status_code().should_log_error() {
110 let root_error = e.root_cause().unwrap_or(&e);
111 error!(e; "Failed to handle request, error: {}", root_error.to_string());
112 } else {
113 debug!("Failed to handle request, err: {:?}", e);
115 }
116 e
117 })
118 };
119
120 match &self.runtime {
121 Some(runtime) => {
122 runtime
130 .spawn(result_future)
131 .await
132 .context(JoinTaskSnafu)
133 .inspect_err(|e| {
134 timer.record(e.status_code());
135 })?
136 }
137 None => result_future.await,
138 }
139 }
140
141 pub(crate) async fn put_record_batches(
142 &self,
143 mut stream: PutRecordBatchRequestStream,
144 result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
145 ) {
146 let handler = self.handler.clone();
147 let runtime = self
148 .runtime
149 .clone()
150 .unwrap_or_else(common_runtime::global_runtime);
151 runtime.spawn(async move {
152 let mut table_ref: Option<TableRef> = None;
154
155 let mut decoder = FlightDecoder::default();
156 while let Some(request) = stream.next().await {
157 let request = match request {
158 Ok(request) => request,
159 Err(e) => {
160 let _ = result_sender.try_send(Err(e));
161 break;
162 }
163 };
164 let PutRecordBatchRequest {
165 table_name,
166 request_id,
167 data,
168 } = request;
169
170 let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer();
171 let result = handler
172 .put_record_batch(&table_name, &mut table_ref, &mut decoder, data)
173 .await
174 .inspect_err(|e| error!(e; "Failed to handle flight record batches"));
175 timer.observe_duration();
176 let result = result
177 .map(|x| DoPutResponse::new(request_id, x))
178 .map_err(Into::into);
179 if result_sender.try_send(result).is_err() {
180 warn!(r#""DoPut" client maybe unreachable, abort handling its message"#);
181 break;
182 }
183 }
184 });
185 }
186
187 pub(crate) async fn validate_auth(
188 &self,
189 username_and_password: Option<&str>,
190 db: Option<&str>,
191 ) -> Result<bool> {
192 if self.user_provider.is_none() {
193 return Ok(true);
194 }
195
196 let username_and_password = username_and_password.context(NotFoundAuthHeaderSnafu)?;
197 let username_and_password = BASE64_STANDARD
198 .decode(username_and_password)
199 .context(InvalidBase64ValueSnafu)
200 .and_then(|x| String::from_utf8(x).context(InvalidAuthHeaderInvalidUtf8ValueSnafu))?;
201
202 let mut split = username_and_password.splitn(2, ':');
203 let (username, password) = match (split.next(), split.next()) {
204 (Some(username), Some(password)) => (username, password),
205 (Some(username), None) => (username, ""),
206 (None, None) => return Ok(false),
207 _ => unreachable!(), };
209
210 let (catalog, schema) = if let Some(db) = db {
211 parse_catalog_and_schema_from_db_string(db)
212 } else {
213 (
214 DEFAULT_CATALOG_NAME.to_string(),
215 DEFAULT_SCHEMA_NAME.to_string(),
216 )
217 };
218 let header = RequestHeader {
219 authorization: Some(AuthHeader {
220 auth_scheme: Some(AuthScheme::Basic(Basic {
221 username: username.to_string(),
222 password: password.to_string(),
223 })),
224 }),
225 catalog,
226 schema,
227 ..Default::default()
228 };
229
230 Ok(auth(
231 self.user_provider.clone(),
232 Some(&header),
233 &QueryContext::arc(),
234 )
235 .await
236 .is_ok())
237 }
238}
239
240pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
241 request
242 .request
243 .as_ref()
244 .map(request_type)
245 .unwrap_or_default()
246}
247
248pub(crate) async fn auth(
249 user_provider: Option<UserProviderRef>,
250 header: Option<&RequestHeader>,
251 query_ctx: &QueryContextRef,
252) -> Result<UserInfoRef> {
253 let Some(user_provider) = user_provider else {
254 return Ok(auth::userinfo_by_name(None));
255 };
256
257 let auth_scheme = header
258 .and_then(|header| {
259 header
260 .authorization
261 .as_ref()
262 .and_then(|x| x.auth_scheme.clone())
263 })
264 .context(NotFoundAuthHeaderSnafu)?;
265
266 match auth_scheme {
267 AuthScheme::Basic(Basic { username, password }) => user_provider
268 .auth(
269 Identity::UserId(&username, None),
270 Password::PlainText(password.into()),
271 query_ctx.current_catalog(),
272 &query_ctx.current_schema(),
273 )
274 .await
275 .context(AuthSnafu),
276 AuthScheme::Token(_) => Err(UnsupportedAuthScheme {
277 name: "Token AuthScheme".to_string(),
278 }),
279 }
280 .inspect_err(|e| {
281 METRIC_AUTH_FAILURE
282 .with_label_values(&[e.status_code().as_ref()])
283 .inc();
284 })
285}
286
287pub(crate) fn create_query_context(
288 header: Option<&RequestHeader>,
289 mut extensions: Vec<(String, String)>,
290) -> Result<QueryContextRef> {
291 let (catalog, schema) = header
292 .map(|header| {
293 if !header.dbname.is_empty() {
296 parse_catalog_and_schema_from_db_string(&header.dbname)
297 } else {
298 (
299 if !header.catalog.is_empty() {
300 header.catalog.to_lowercase()
301 } else {
302 DEFAULT_CATALOG_NAME.to_string()
303 },
304 if !header.schema.is_empty() {
305 header.schema.to_lowercase()
306 } else {
307 DEFAULT_SCHEMA_NAME.to_string()
308 },
309 )
310 }
311 })
312 .unwrap_or_else(|| {
313 (
314 DEFAULT_CATALOG_NAME.to_string(),
315 DEFAULT_SCHEMA_NAME.to_string(),
316 )
317 });
318 let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
319 let mut ctx_builder = QueryContextBuilder::default()
320 .current_catalog(catalog)
321 .current_schema(schema)
322 .timezone(timezone);
323
324 if let Some(x) = extensions
325 .iter()
326 .position(|(k, _)| k == READ_PREFERENCE_HINT)
327 {
328 let (k, v) = extensions.swap_remove(x);
329 let Ok(read_preference) = ReadPreference::from_str(&v) else {
330 return UnknownHintSnafu {
331 hint: format!("{k}={v}"),
332 }
333 .fail();
334 };
335 ctx_builder = ctx_builder.read_preference(read_preference);
336 }
337
338 for (key, value) in extensions {
339 ctx_builder = ctx_builder.set_extension(key, value);
340 }
341 Ok(ctx_builder.build().into())
342}
343
344pub(crate) struct RequestTimer {
348 start: Instant,
349 db: String,
350 request_type: String,
351 status_code: StatusCode,
352}
353
354impl RequestTimer {
355 pub fn new(db: String, request_type: String) -> RequestTimer {
357 RequestTimer {
358 start: Instant::now(),
359 db,
360 request_type,
361 status_code: StatusCode::Success,
362 }
363 }
364
365 pub fn record(mut self, status_code: StatusCode) {
367 self.status_code = status_code;
368 }
369}
370
371impl Drop for RequestTimer {
372 fn drop(&mut self) {
373 METRIC_SERVER_GRPC_DB_REQUEST_TIMER
374 .with_label_values(&[
375 self.db.as_str(),
376 self.request_type.as_str(),
377 self.status_code.as_ref(),
378 ])
379 .observe(self.start.elapsed().as_secs_f64());
380 }
381}
382
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_time::Timezone;

    use super::*;

    /// Checks the header-derived fields and that the read-preference hint is consumed rather
    /// than stored as an extension.
    #[test]
    fn test_create_query_context() {
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };
        let hints = vec![
            ("auto_create_table".to_string(), "true".to_string()),
            ("read_preference".to_string(), "leader".to_string()),
        ];

        let ctx = create_query_context(Some(&header), hints).unwrap();

        assert_eq!(ctx.current_catalog(), "cat-a-log");
        assert_eq!(ctx.current_schema(), DEFAULT_SCHEMA_NAME);

        let expected_timezone = Timezone::Offset(FixedOffset::east_opt(3600).unwrap());
        assert_eq!(ctx.timezone(), expected_timezone);

        assert!(matches!(ctx.read_preference(), ReadPreference::Leader));

        let remaining: Vec<_> = ctx.extensions().into_iter().collect();
        let expected_extensions = vec![("auto_create_table".to_string(), "true".to_string())];
        assert_eq!(remaining, expected_extensions);
    }
}