1use std::str::FromStr;
18use std::time::Instant;
19
20use api::helper::request_type;
21use api::v1::{GreptimeRequest, RequestHeader};
22use auth::UserProviderRef;
23use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
24use common_catalog::parse_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_error::status_code::StatusCode;
27use common_grpc::flight::FlightDecoder;
28use common_grpc::flight::do_put::DoPutResponse;
29use common_query::Output;
30use common_runtime::Runtime;
31use common_runtime::runtime::RuntimeTrait;
32use common_session::ReadPreference;
33use common_telemetry::tracing_context::{FutureExt, TracingContext};
34use common_telemetry::{debug, error, tracing, warn};
35use common_time::timezone::parse_timezone;
36use futures_util::StreamExt;
37use session::context::{Channel, QueryContextBuilder, QueryContextRef};
38use session::hints::READ_PREFERENCE_HINT;
39use snafu::{OptionExt, ResultExt};
40use table::TableRef;
41use tokio::sync::mpsc;
42use tokio::sync::mpsc::error::TrySendError;
43
44use crate::error::{InvalidQuerySnafu, JoinTaskSnafu, Result, UnknownHintSnafu};
45use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream};
46use crate::grpc::{FlightCompression, TonicResult, context_auth};
47use crate::metrics;
48use crate::metrics::METRIC_SERVER_GRPC_DB_REQUEST_TIMER;
49use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
50
51#[derive(Clone)]
52pub struct GreptimeRequestHandler {
53 handler: ServerGrpcQueryHandlerRef,
54 pub(crate) user_provider: Option<UserProviderRef>,
55 runtime: Option<Runtime>,
56 pub(crate) flight_compression: FlightCompression,
57}
58
59impl GreptimeRequestHandler {
60 pub fn new(
61 handler: ServerGrpcQueryHandlerRef,
62 user_provider: Option<UserProviderRef>,
63 runtime: Option<Runtime>,
64 flight_compression: FlightCompression,
65 ) -> Self {
66 Self {
67 handler,
68 user_provider,
69 runtime,
70 flight_compression,
71 }
72 }
73
74 #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
75 pub(crate) async fn handle_request(
76 &self,
77 request: GreptimeRequest,
78 hints: Vec<(String, String)>,
79 ) -> Result<Output> {
80 let query = request.request.context(InvalidQuerySnafu {
81 reason: "Expecting non-empty GreptimeRequest.",
82 })?;
83
84 let header = request.header.as_ref();
85 let query_ctx = create_query_context(Channel::Grpc, header, hints)?;
86 let user_info = context_auth::auth(self.user_provider.clone(), header, &query_ctx).await?;
87 query_ctx.set_current_user(user_info);
88
89 let handler = self.handler.clone();
90 let request_type = request_type(&query).to_string();
91 let db = query_ctx.get_db_string();
92 let timer = RequestTimer::new(db.clone(), request_type);
93 let tracing_context = TracingContext::from_current_span();
94
95 let result_future = async move {
96 handler
97 .do_query(query, query_ctx)
98 .trace(tracing_context.attach(tracing::info_span!(
99 "GreptimeRequestHandler::handle_request_runtime"
100 )))
101 .await
102 .map_err(|e| {
103 if e.status_code().should_log_error() {
104 let root_error = e.root_cause().unwrap_or(&e);
105 error!(e; "Failed to handle request, error: {}", root_error.to_string());
106 } else {
107 debug!("Failed to handle request, err: {:?}", e);
109 }
110 e
111 })
112 };
113
114 match &self.runtime {
115 Some(runtime) => {
116 runtime
124 .spawn(result_future)
125 .await
126 .context(JoinTaskSnafu)
127 .inspect_err(|e| {
128 timer.record(e.status_code());
129 })?
130 }
131 None => result_future.await,
132 }
133 }
134
135 pub(crate) async fn put_record_batches(
136 &self,
137 mut stream: PutRecordBatchRequestStream,
138 result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
139 query_ctx: QueryContextRef,
140 ) {
141 let handler = self.handler.clone();
142 let runtime = self
143 .runtime
144 .clone()
145 .unwrap_or_else(common_runtime::global_runtime);
146 runtime.spawn(async move {
147 let mut table_ref: Option<TableRef> = None;
149
150 let mut decoder = FlightDecoder::default();
151 while let Some(request) = stream.next().await {
152 let request = match request {
153 Ok(request) => request,
154 Err(e) => {
155 let _ = result_sender.try_send(Err(e));
156 break;
157 }
158 };
159 let PutRecordBatchRequest {
160 table_name,
161 request_id,
162 data,
163 } = request;
164
165 let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer();
166 let result = handler
167 .put_record_batch(&table_name, &mut table_ref, &mut decoder, data, query_ctx.clone())
168 .await
169 .inspect_err(|e| error!(e; "Failed to handle flight record batches"));
170 timer.observe_duration();
171 let result = result
172 .map(|x| DoPutResponse::new(request_id, x))
173 .map_err(Into::into);
174 if let Err(e)= result_sender.try_send(result)
175 && let TrySendError::Closed(_) = e {
176 warn!(r#""DoPut" client with request_id {} maybe unreachable, abort handling its message"#, request_id);
177 break;
178 }
179 }
180 });
181 }
182}
183
184pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
185 request
186 .request
187 .as_ref()
188 .map(request_type)
189 .unwrap_or_default()
190}
191
192pub(crate) fn create_query_context(
195 channel: Channel,
196 header: Option<&RequestHeader>,
197 mut extensions: Vec<(String, String)>,
198) -> Result<QueryContextRef> {
199 let (catalog, schema) = header
200 .map(|header| {
201 if !header.dbname.is_empty() {
204 parse_catalog_and_schema_from_db_string(&header.dbname)
205 } else {
206 (
207 if !header.catalog.is_empty() {
208 header.catalog.to_lowercase()
209 } else {
210 DEFAULT_CATALOG_NAME.to_string()
211 },
212 if !header.schema.is_empty() {
213 header.schema.to_lowercase()
214 } else {
215 DEFAULT_SCHEMA_NAME.to_string()
216 },
217 )
218 }
219 })
220 .unwrap_or_else(|| {
221 (
222 DEFAULT_CATALOG_NAME.to_string(),
223 DEFAULT_SCHEMA_NAME.to_string(),
224 )
225 });
226 let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
227 let mut ctx_builder = QueryContextBuilder::default()
228 .current_catalog(catalog)
229 .current_schema(schema)
230 .timezone(timezone)
231 .channel(channel);
232
233 if let Some(x) = extensions
234 .iter()
235 .position(|(k, _)| k == READ_PREFERENCE_HINT)
236 {
237 let (k, v) = extensions.swap_remove(x);
238 let Ok(read_preference) = ReadPreference::from_str(&v) else {
239 return UnknownHintSnafu {
240 hint: format!("{k}={v}"),
241 }
242 .fail();
243 };
244 ctx_builder = ctx_builder.read_preference(read_preference);
245 }
246
247 for (key, value) in extensions {
248 ctx_builder = ctx_builder.set_extension(key, value);
249 }
250 Ok(ctx_builder.build().into())
251}
252
253pub(crate) struct RequestTimer {
257 start: Instant,
258 db: String,
259 request_type: String,
260 status_code: StatusCode,
261}
262
263impl RequestTimer {
264 pub fn new(db: String, request_type: String) -> RequestTimer {
266 RequestTimer {
267 start: Instant::now(),
268 db,
269 request_type,
270 status_code: StatusCode::Success,
271 }
272 }
273
274 pub fn record(mut self, status_code: StatusCode) {
276 self.status_code = status_code;
277 }
278}
279
280impl Drop for RequestTimer {
281 fn drop(&mut self) {
282 METRIC_SERVER_GRPC_DB_REQUEST_TIMER
283 .with_label_values(&[
284 self.db.as_str(),
285 self.request_type.as_str(),
286 self.status_code.as_ref(),
287 ])
288 .observe(self.start.elapsed().as_secs_f64());
289 }
290}
291
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_time::Timezone;

    use super::*;

    #[test]
    fn test_create_query_context() {
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };
        let query_context = create_query_context(
            Channel::Unknown,
            Some(&header),
            vec![
                ("auto_create_table".to_string(), "true".to_string()),
                ("read_preference".to_string(), "leader".to_string()),
            ],
        )
        .unwrap();
        assert_eq!(query_context.current_catalog(), "cat-a-log");
        assert_eq!(query_context.current_schema(), DEFAULT_SCHEMA_NAME);
        assert_eq!(
            query_context.timezone(),
            Timezone::Offset(FixedOffset::east_opt(3600).unwrap())
        );
        assert!(matches!(
            query_context.read_preference(),
            ReadPreference::Leader
        ));
        // The read-preference hint is consumed and must not remain as an extension.
        assert_eq!(
            query_context.extensions().into_iter().collect::<Vec<_>>(),
            vec![("auto_create_table".to_string(), "true".to_string())]
        );
    }

    #[test]
    fn test_create_query_context_without_header() {
        let query_context = create_query_context(Channel::Unknown, None, vec![]).unwrap();
        assert_eq!(query_context.current_catalog(), DEFAULT_CATALOG_NAME);
        assert_eq!(query_context.current_schema(), DEFAULT_SCHEMA_NAME);
        assert_eq!(query_context.extensions().into_iter().count(), 0);
    }

    #[test]
    fn test_create_query_context_lowercases_catalog_and_schema() {
        let header = RequestHeader {
            catalog: "MyCatalog".to_string(),
            schema: "MySchema".to_string(),
            ..Default::default()
        };
        let query_context =
            create_query_context(Channel::Unknown, Some(&header), vec![]).unwrap();
        assert_eq!(query_context.current_catalog(), "mycatalog");
        assert_eq!(query_context.current_schema(), "myschema");
    }

    #[test]
    fn test_create_query_context_rejects_unknown_read_preference() {
        let result = create_query_context(
            Channel::Unknown,
            None,
            vec![(
                "read_preference".to_string(),
                "not-a-preference".to_string(),
            )],
        );
        assert!(result.is_err());
    }
}