1use std::str::FromStr;
18use std::time::Instant;
19
20use api::helper::request_type;
21use api::v1::auth_header::AuthScheme;
22use api::v1::{AuthHeader, Basic, GreptimeRequest, RequestHeader};
23use auth::{Identity, Password, UserInfoRef, UserProviderRef};
24use base64::prelude::BASE64_STANDARD;
25use base64::Engine;
26use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
27use common_catalog::parse_catalog_and_schema_from_db_string;
28use common_error::ext::ErrorExt;
29use common_error::status_code::StatusCode;
30use common_grpc::flight::do_put::DoPutResponse;
31use common_grpc::flight::FlightDecoder;
32use common_query::Output;
33use common_runtime::runtime::RuntimeTrait;
34use common_runtime::Runtime;
35use common_session::ReadPreference;
36use common_telemetry::tracing_context::{FutureExt, TracingContext};
37use common_telemetry::{debug, error, tracing, warn};
38use common_time::timezone::parse_timezone;
39use futures_util::StreamExt;
40use session::context::{QueryContext, QueryContextBuilder, QueryContextRef};
41use snafu::{OptionExt, ResultExt};
42use table::metadata::TableId;
43use tokio::sync::mpsc;
44
45use crate::error::Error::UnsupportedAuthScheme;
46use crate::error::{
47 AuthSnafu, InvalidAuthHeaderInvalidUtf8ValueSnafu, InvalidBase64ValueSnafu, InvalidQuerySnafu,
48 JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result, UnknownHintSnafu,
49};
50use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream};
51use crate::grpc::TonicResult;
52use crate::hint_headers::READ_PREFERENCE_HINT;
53use crate::metrics;
54use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
55use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
56
57#[derive(Clone)]
58pub struct GreptimeRequestHandler {
59 handler: ServerGrpcQueryHandlerRef,
60 user_provider: Option<UserProviderRef>,
61 runtime: Option<Runtime>,
62}
63
64impl GreptimeRequestHandler {
65 pub fn new(
66 handler: ServerGrpcQueryHandlerRef,
67 user_provider: Option<UserProviderRef>,
68 runtime: Option<Runtime>,
69 ) -> Self {
70 Self {
71 handler,
72 user_provider,
73 runtime,
74 }
75 }
76
77 #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
78 pub(crate) async fn handle_request(
79 &self,
80 request: GreptimeRequest,
81 hints: Vec<(String, String)>,
82 ) -> Result<Output> {
83 let query = request.request.context(InvalidQuerySnafu {
84 reason: "Expecting non-empty GreptimeRequest.",
85 })?;
86
87 let header = request.header.as_ref();
88 let query_ctx = create_query_context(header, hints)?;
89 let user_info = auth(self.user_provider.clone(), header, &query_ctx).await?;
90 query_ctx.set_current_user(user_info);
91
92 let handler = self.handler.clone();
93 let request_type = request_type(&query).to_string();
94 let db = query_ctx.get_db_string();
95 let timer = RequestTimer::new(db.clone(), request_type);
96 let tracing_context = TracingContext::from_current_span();
97
98 let result_future = async move {
99 handler
100 .do_query(query, query_ctx)
101 .trace(tracing_context.attach(tracing::info_span!(
102 "GreptimeRequestHandler::handle_request_runtime"
103 )))
104 .await
105 .map_err(|e| {
106 if e.status_code().should_log_error() {
107 let root_error = e.root_cause().unwrap_or(&e);
108 error!(e; "Failed to handle request, error: {}", root_error.to_string());
109 } else {
110 debug!("Failed to handle request, err: {:?}", e);
112 }
113 e
114 })
115 };
116
117 match &self.runtime {
118 Some(runtime) => {
119 runtime
127 .spawn(result_future)
128 .await
129 .context(JoinTaskSnafu)
130 .inspect_err(|e| {
131 timer.record(e.status_code());
132 })?
133 }
134 None => result_future.await,
135 }
136 }
137
138 pub(crate) async fn put_record_batches(
139 &self,
140 mut stream: PutRecordBatchRequestStream,
141 result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
142 ) {
143 let handler = self.handler.clone();
144 let runtime = self
145 .runtime
146 .clone()
147 .unwrap_or_else(common_runtime::global_runtime);
148 runtime.spawn(async move {
149 let mut table_id: Option<TableId> = None;
151
152 let mut decoder = FlightDecoder::default();
153 while let Some(request) = stream.next().await {
154 let request = match request {
155 Ok(request) => request,
156 Err(e) => {
157 let _ = result_sender.try_send(Err(e));
158 break;
159 }
160 };
161 let PutRecordBatchRequest {
162 table_name,
163 request_id,
164 data,
165 } = request;
166
167 let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer();
168 let result = handler
169 .put_record_batch(&table_name, &mut table_id, &mut decoder, data)
170 .await;
171 timer.observe_duration();
172 let result = result
173 .map(|x| DoPutResponse::new(request_id, x))
174 .map_err(Into::into);
175 if result_sender.try_send(result).is_err() {
176 warn!(r#""DoPut" client maybe unreachable, abort handling its message"#);
177 break;
178 }
179 }
180 });
181 }
182
183 pub(crate) async fn validate_auth(
184 &self,
185 username_and_password: Option<&str>,
186 db: Option<&str>,
187 ) -> Result<bool> {
188 if self.user_provider.is_none() {
189 return Ok(true);
190 }
191
192 let username_and_password = username_and_password.context(NotFoundAuthHeaderSnafu)?;
193 let username_and_password = BASE64_STANDARD
194 .decode(username_and_password)
195 .context(InvalidBase64ValueSnafu)
196 .and_then(|x| String::from_utf8(x).context(InvalidAuthHeaderInvalidUtf8ValueSnafu))?;
197
198 let mut split = username_and_password.splitn(2, ':');
199 let (username, password) = match (split.next(), split.next()) {
200 (Some(username), Some(password)) => (username, password),
201 (Some(username), None) => (username, ""),
202 (None, None) => return Ok(false),
203 _ => unreachable!(), };
205
206 let (catalog, schema) = if let Some(db) = db {
207 parse_catalog_and_schema_from_db_string(db)
208 } else {
209 (
210 DEFAULT_CATALOG_NAME.to_string(),
211 DEFAULT_SCHEMA_NAME.to_string(),
212 )
213 };
214 let header = RequestHeader {
215 authorization: Some(AuthHeader {
216 auth_scheme: Some(AuthScheme::Basic(Basic {
217 username: username.to_string(),
218 password: password.to_string(),
219 })),
220 }),
221 catalog,
222 schema,
223 ..Default::default()
224 };
225
226 Ok(auth(
227 self.user_provider.clone(),
228 Some(&header),
229 &QueryContext::arc(),
230 )
231 .await
232 .is_ok())
233 }
234}
235
236pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
237 request
238 .request
239 .as_ref()
240 .map(request_type)
241 .unwrap_or_default()
242}
243
244pub(crate) async fn auth(
245 user_provider: Option<UserProviderRef>,
246 header: Option<&RequestHeader>,
247 query_ctx: &QueryContextRef,
248) -> Result<UserInfoRef> {
249 let Some(user_provider) = user_provider else {
250 return Ok(auth::userinfo_by_name(None));
251 };
252
253 let auth_scheme = header
254 .and_then(|header| {
255 header
256 .authorization
257 .as_ref()
258 .and_then(|x| x.auth_scheme.clone())
259 })
260 .context(NotFoundAuthHeaderSnafu)?;
261
262 match auth_scheme {
263 AuthScheme::Basic(Basic { username, password }) => user_provider
264 .auth(
265 Identity::UserId(&username, None),
266 Password::PlainText(password.into()),
267 query_ctx.current_catalog(),
268 &query_ctx.current_schema(),
269 )
270 .await
271 .context(AuthSnafu),
272 AuthScheme::Token(_) => Err(UnsupportedAuthScheme {
273 name: "Token AuthScheme".to_string(),
274 }),
275 }
276 .inspect_err(|e| {
277 METRIC_AUTH_FAILURE
278 .with_label_values(&[e.status_code().as_ref()])
279 .inc();
280 })
281}
282
283pub(crate) fn create_query_context(
284 header: Option<&RequestHeader>,
285 mut extensions: Vec<(String, String)>,
286) -> Result<QueryContextRef> {
287 let (catalog, schema) = header
288 .map(|header| {
289 if !header.dbname.is_empty() {
292 parse_catalog_and_schema_from_db_string(&header.dbname)
293 } else {
294 (
295 if !header.catalog.is_empty() {
296 header.catalog.to_lowercase()
297 } else {
298 DEFAULT_CATALOG_NAME.to_string()
299 },
300 if !header.schema.is_empty() {
301 header.schema.to_lowercase()
302 } else {
303 DEFAULT_SCHEMA_NAME.to_string()
304 },
305 )
306 }
307 })
308 .unwrap_or_else(|| {
309 (
310 DEFAULT_CATALOG_NAME.to_string(),
311 DEFAULT_SCHEMA_NAME.to_string(),
312 )
313 });
314 let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
315 let mut ctx_builder = QueryContextBuilder::default()
316 .current_catalog(catalog)
317 .current_schema(schema)
318 .timezone(timezone);
319
320 if let Some(x) = extensions
321 .iter()
322 .position(|(k, _)| k == READ_PREFERENCE_HINT)
323 {
324 let (k, v) = extensions.swap_remove(x);
325 let Ok(read_preference) = ReadPreference::from_str(&v) else {
326 return UnknownHintSnafu {
327 hint: format!("{k}={v}"),
328 }
329 .fail();
330 };
331 ctx_builder = ctx_builder.read_preference(read_preference);
332 }
333
334 for (key, value) in extensions {
335 ctx_builder = ctx_builder.set_extension(key, value);
336 }
337 Ok(ctx_builder.build().into())
338}
339
340pub(crate) struct RequestTimer {
344 start: Instant,
345 db: String,
346 request_type: String,
347 status_code: StatusCode,
348}
349
350impl RequestTimer {
351 pub fn new(db: String, request_type: String) -> RequestTimer {
353 RequestTimer {
354 start: Instant::now(),
355 db,
356 request_type,
357 status_code: StatusCode::Success,
358 }
359 }
360
361 pub fn record(mut self, status_code: StatusCode) {
363 self.status_code = status_code;
364 }
365}
366
367impl Drop for RequestTimer {
368 fn drop(&mut self) {
369 METRIC_SERVER_GRPC_DB_REQUEST_TIMER
370 .with_label_values(&[
371 self.db.as_str(),
372 self.request_type.as_str(),
373 self.status_code.as_ref(),
374 ])
375 .observe(self.start.elapsed().as_secs_f64());
376 }
377}
378
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_time::Timezone;

    use super::*;

    #[test]
    fn test_create_query_context() {
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };
        let hints = vec![
            ("auto_create_table".to_string(), "true".to_string()),
            ("read_preference".to_string(), "leader".to_string()),
        ];

        let ctx = create_query_context(Some(&header), hints).unwrap();

        // The catalog comes from the header; the empty schema falls back to the default.
        assert_eq!(ctx.current_catalog(), "cat-a-log");
        assert_eq!(ctx.current_schema(), DEFAULT_SCHEMA_NAME);

        // "+01:00" is a fixed offset of 3600 seconds east of UTC.
        let expected_timezone = Timezone::Offset(FixedOffset::east_opt(3600).unwrap());
        assert_eq!(ctx.timezone(), expected_timezone);

        // The read-preference hint is consumed instead of being kept as an extension.
        assert!(matches!(ctx.read_preference(), ReadPreference::Leader));
        let extensions: Vec<_> = ctx.extensions().into_iter().collect();
        assert_eq!(
            extensions,
            vec![("auto_create_table".to_string(), "true".to_string())]
        );
    }
}