1use std::str::FromStr;
18use std::time::Instant;
19
20use api::helper::request_type;
21use api::v1::auth_header::AuthScheme;
22use api::v1::{AuthHeader, Basic, GreptimeRequest, RequestHeader};
23use auth::{Identity, Password, UserInfoRef, UserProviderRef};
24use base64::prelude::BASE64_STANDARD;
25use base64::Engine;
26use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
27use common_catalog::parse_catalog_and_schema_from_db_string;
28use common_error::ext::ErrorExt;
29use common_error::status_code::StatusCode;
30use common_grpc::flight::do_put::DoPutResponse;
31use common_grpc::flight::FlightDecoder;
32use common_query::Output;
33use common_runtime::runtime::RuntimeTrait;
34use common_runtime::Runtime;
35use common_session::ReadPreference;
36use common_telemetry::tracing_context::{FutureExt, TracingContext};
37use common_telemetry::{debug, error, tracing, warn};
38use common_time::timezone::parse_timezone;
39use futures_util::StreamExt;
40use session::context::{QueryContext, QueryContextBuilder, QueryContextRef};
41use session::hints::READ_PREFERENCE_HINT;
42use snafu::{OptionExt, ResultExt};
43use table::metadata::TableId;
44use tokio::sync::mpsc;
45
46use crate::error::Error::UnsupportedAuthScheme;
47use crate::error::{
48 AuthSnafu, InvalidAuthHeaderInvalidUtf8ValueSnafu, InvalidBase64ValueSnafu, InvalidQuerySnafu,
49 JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result, UnknownHintSnafu,
50};
51use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream};
52use crate::grpc::TonicResult;
53use crate::metrics;
54use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
55use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
56
57#[derive(Clone)]
58pub struct GreptimeRequestHandler {
59 handler: ServerGrpcQueryHandlerRef,
60 user_provider: Option<UserProviderRef>,
61 runtime: Option<Runtime>,
62}
63
64impl GreptimeRequestHandler {
65 pub fn new(
66 handler: ServerGrpcQueryHandlerRef,
67 user_provider: Option<UserProviderRef>,
68 runtime: Option<Runtime>,
69 ) -> Self {
70 Self {
71 handler,
72 user_provider,
73 runtime,
74 }
75 }
76
77 #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
78 pub(crate) async fn handle_request(
79 &self,
80 request: GreptimeRequest,
81 hints: Vec<(String, String)>,
82 ) -> Result<Output> {
83 let query = request.request.context(InvalidQuerySnafu {
84 reason: "Expecting non-empty GreptimeRequest.",
85 })?;
86
87 let header = request.header.as_ref();
88 let query_ctx = create_query_context(header, hints)?;
89 let user_info = auth(self.user_provider.clone(), header, &query_ctx).await?;
90 query_ctx.set_current_user(user_info);
91
92 let handler = self.handler.clone();
93 let request_type = request_type(&query).to_string();
94 let db = query_ctx.get_db_string();
95 let timer = RequestTimer::new(db.clone(), request_type);
96 let tracing_context = TracingContext::from_current_span();
97
98 let result_future = async move {
99 handler
100 .do_query(query, query_ctx)
101 .trace(tracing_context.attach(tracing::info_span!(
102 "GreptimeRequestHandler::handle_request_runtime"
103 )))
104 .await
105 .map_err(|e| {
106 if e.status_code().should_log_error() {
107 let root_error = e.root_cause().unwrap_or(&e);
108 error!(e; "Failed to handle request, error: {}", root_error.to_string());
109 } else {
110 debug!("Failed to handle request, err: {:?}", e);
112 }
113 e
114 })
115 };
116
117 match &self.runtime {
118 Some(runtime) => {
119 runtime
127 .spawn(result_future)
128 .await
129 .context(JoinTaskSnafu)
130 .inspect_err(|e| {
131 timer.record(e.status_code());
132 })?
133 }
134 None => result_future.await,
135 }
136 }
137
138 pub(crate) async fn put_record_batches(
139 &self,
140 mut stream: PutRecordBatchRequestStream,
141 result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
142 ) {
143 let handler = self.handler.clone();
144 let runtime = self
145 .runtime
146 .clone()
147 .unwrap_or_else(common_runtime::global_runtime);
148 runtime.spawn(async move {
149 let mut table_id: Option<TableId> = None;
151
152 let mut decoder = FlightDecoder::default();
153 while let Some(request) = stream.next().await {
154 let request = match request {
155 Ok(request) => request,
156 Err(e) => {
157 let _ = result_sender.try_send(Err(e));
158 break;
159 }
160 };
161 let PutRecordBatchRequest {
162 table_name,
163 request_id,
164 data,
165 } = request;
166
167 let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer();
168 let result = handler
169 .put_record_batch(&table_name, &mut table_id, &mut decoder, data)
170 .await
171 .inspect_err(|e| error!(e; "Failed to handle flight record batches"));
172 timer.observe_duration();
173 let result = result
174 .map(|x| DoPutResponse::new(request_id, x))
175 .map_err(Into::into);
176 if result_sender.try_send(result).is_err() {
177 warn!(r#""DoPut" client maybe unreachable, abort handling its message"#);
178 break;
179 }
180 }
181 });
182 }
183
184 pub(crate) async fn validate_auth(
185 &self,
186 username_and_password: Option<&str>,
187 db: Option<&str>,
188 ) -> Result<bool> {
189 if self.user_provider.is_none() {
190 return Ok(true);
191 }
192
193 let username_and_password = username_and_password.context(NotFoundAuthHeaderSnafu)?;
194 let username_and_password = BASE64_STANDARD
195 .decode(username_and_password)
196 .context(InvalidBase64ValueSnafu)
197 .and_then(|x| String::from_utf8(x).context(InvalidAuthHeaderInvalidUtf8ValueSnafu))?;
198
199 let mut split = username_and_password.splitn(2, ':');
200 let (username, password) = match (split.next(), split.next()) {
201 (Some(username), Some(password)) => (username, password),
202 (Some(username), None) => (username, ""),
203 (None, None) => return Ok(false),
204 _ => unreachable!(), };
206
207 let (catalog, schema) = if let Some(db) = db {
208 parse_catalog_and_schema_from_db_string(db)
209 } else {
210 (
211 DEFAULT_CATALOG_NAME.to_string(),
212 DEFAULT_SCHEMA_NAME.to_string(),
213 )
214 };
215 let header = RequestHeader {
216 authorization: Some(AuthHeader {
217 auth_scheme: Some(AuthScheme::Basic(Basic {
218 username: username.to_string(),
219 password: password.to_string(),
220 })),
221 }),
222 catalog,
223 schema,
224 ..Default::default()
225 };
226
227 Ok(auth(
228 self.user_provider.clone(),
229 Some(&header),
230 &QueryContext::arc(),
231 )
232 .await
233 .is_ok())
234 }
235}
236
237pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
238 request
239 .request
240 .as_ref()
241 .map(request_type)
242 .unwrap_or_default()
243}
244
245pub(crate) async fn auth(
246 user_provider: Option<UserProviderRef>,
247 header: Option<&RequestHeader>,
248 query_ctx: &QueryContextRef,
249) -> Result<UserInfoRef> {
250 let Some(user_provider) = user_provider else {
251 return Ok(auth::userinfo_by_name(None));
252 };
253
254 let auth_scheme = header
255 .and_then(|header| {
256 header
257 .authorization
258 .as_ref()
259 .and_then(|x| x.auth_scheme.clone())
260 })
261 .context(NotFoundAuthHeaderSnafu)?;
262
263 match auth_scheme {
264 AuthScheme::Basic(Basic { username, password }) => user_provider
265 .auth(
266 Identity::UserId(&username, None),
267 Password::PlainText(password.into()),
268 query_ctx.current_catalog(),
269 &query_ctx.current_schema(),
270 )
271 .await
272 .context(AuthSnafu),
273 AuthScheme::Token(_) => Err(UnsupportedAuthScheme {
274 name: "Token AuthScheme".to_string(),
275 }),
276 }
277 .inspect_err(|e| {
278 METRIC_AUTH_FAILURE
279 .with_label_values(&[e.status_code().as_ref()])
280 .inc();
281 })
282}
283
284pub(crate) fn create_query_context(
285 header: Option<&RequestHeader>,
286 mut extensions: Vec<(String, String)>,
287) -> Result<QueryContextRef> {
288 let (catalog, schema) = header
289 .map(|header| {
290 if !header.dbname.is_empty() {
293 parse_catalog_and_schema_from_db_string(&header.dbname)
294 } else {
295 (
296 if !header.catalog.is_empty() {
297 header.catalog.to_lowercase()
298 } else {
299 DEFAULT_CATALOG_NAME.to_string()
300 },
301 if !header.schema.is_empty() {
302 header.schema.to_lowercase()
303 } else {
304 DEFAULT_SCHEMA_NAME.to_string()
305 },
306 )
307 }
308 })
309 .unwrap_or_else(|| {
310 (
311 DEFAULT_CATALOG_NAME.to_string(),
312 DEFAULT_SCHEMA_NAME.to_string(),
313 )
314 });
315 let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
316 let mut ctx_builder = QueryContextBuilder::default()
317 .current_catalog(catalog)
318 .current_schema(schema)
319 .timezone(timezone);
320
321 if let Some(x) = extensions
322 .iter()
323 .position(|(k, _)| k == READ_PREFERENCE_HINT)
324 {
325 let (k, v) = extensions.swap_remove(x);
326 let Ok(read_preference) = ReadPreference::from_str(&v) else {
327 return UnknownHintSnafu {
328 hint: format!("{k}={v}"),
329 }
330 .fail();
331 };
332 ctx_builder = ctx_builder.read_preference(read_preference);
333 }
334
335 for (key, value) in extensions {
336 ctx_builder = ctx_builder.set_extension(key, value);
337 }
338 Ok(ctx_builder.build().into())
339}
340
341pub(crate) struct RequestTimer {
345 start: Instant,
346 db: String,
347 request_type: String,
348 status_code: StatusCode,
349}
350
351impl RequestTimer {
352 pub fn new(db: String, request_type: String) -> RequestTimer {
354 RequestTimer {
355 start: Instant::now(),
356 db,
357 request_type,
358 status_code: StatusCode::Success,
359 }
360 }
361
362 pub fn record(mut self, status_code: StatusCode) {
364 self.status_code = status_code;
365 }
366}
367
368impl Drop for RequestTimer {
369 fn drop(&mut self) {
370 METRIC_SERVER_GRPC_DB_REQUEST_TIMER
371 .with_label_values(&[
372 self.db.as_str(),
373 self.request_type.as_str(),
374 self.status_code.as_ref(),
375 ])
376 .observe(self.start.elapsed().as_secs_f64());
377 }
378}
379
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_time::Timezone;

    use super::*;

    /// `create_query_context` should:
    /// * take the catalog and timezone from the header;
    /// * fall back to the default schema;
    /// * consume the `read_preference` hint and keep every other hint as an extension.
    #[test]
    fn test_create_query_context() {
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };
        let hints = vec![
            ("auto_create_table".to_string(), "true".to_string()),
            ("read_preference".to_string(), "leader".to_string()),
        ];

        let ctx = create_query_context(Some(&header), hints).unwrap();

        assert_eq!(ctx.current_catalog(), "cat-a-log");
        assert_eq!(ctx.current_schema(), DEFAULT_SCHEMA_NAME);

        let one_hour_east = FixedOffset::east_opt(3600).unwrap();
        assert_eq!(ctx.timezone(), Timezone::Offset(one_hour_east));

        assert!(matches!(ctx.read_preference(), ReadPreference::Leader));

        let extensions = ctx.extensions().into_iter().collect::<Vec<_>>();
        assert_eq!(
            extensions,
            vec![("auto_create_table".to_string(), "true".to_string())]
        );
    }
}