1use std::sync::Arc;
16
17use api::region::RegionResponse;
18use api::v1::ResponseHeader;
19use api::v1::region::RegionRequest;
20use arc_swap::ArcSwapOption;
21use arrow_flight::Ticket;
22use async_stream::stream;
23use async_trait::async_trait;
24use common_error::ext::BoxedError;
25use common_error::status_code::StatusCode;
26use common_grpc::flight::{FlightDecoder, FlightMessage};
27use common_meta::error::{self as meta_error, Result as MetaResult};
28use common_meta::node_manager::Datanode;
29use common_query::request::QueryRequest;
30use common_recordbatch::error::ExternalSnafu;
31use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper, SendableRecordBatchStream};
32use common_telemetry::error;
33use common_telemetry::tracing::Span;
34use common_telemetry::tracing_context::TracingContext;
35use prost::Message;
36use query::query_engine::DefaultSerializer;
37use snafu::{OptionExt, ResultExt, location};
38use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
39use tokio_stream::StreamExt;
40
41use crate::error::{
42 self, ConvertFlightDataSnafu, FlightGetSnafu, IllegalDatabaseResponseSnafu,
43 IllegalFlightMessagesSnafu, MissingFieldSnafu, Result, ServerSnafu,
44};
45use crate::{Client, Error, metrics};
46
/// Sends region requests and region queries through a [`Client`].
///
/// This type implements [`Datanode`], so it can be used wherever the meta layer
/// needs to talk to a datanode's region server.
#[derive(Debug)]
pub struct RegionRequester {
    /// Underlying client; provides both the raw region gRPC client and Flight clients.
    client: Client,
    /// Passed as the send-compression flag to `Client::make_flight_client`.
    send_compression: bool,
    /// Passed as the accept-compression flag to `Client::make_flight_client`.
    accept_compression: bool,
}
53
54#[async_trait]
55impl Datanode for RegionRequester {
56 async fn handle(&self, request: RegionRequest) -> MetaResult<RegionResponse> {
57 self.handle_inner(request).await.map_err(|err| {
58 if err.should_retry() {
59 meta_error::Error::RetryLater {
60 source: BoxedError::new(err),
61 clean_poisons: false,
62 }
63 } else {
64 meta_error::Error::External {
65 source: BoxedError::new(err),
66 location: location!(),
67 }
68 }
69 })
70 }
71
72 async fn handle_query(&self, request: QueryRequest) -> MetaResult<SendableRecordBatchStream> {
73 let plan = DFLogicalSubstraitConvertor
74 .encode(&request.plan, DefaultSerializer)
75 .map_err(BoxedError::new)
76 .context(meta_error::ExternalSnafu)?
77 .to_vec();
78 let request = api::v1::region::QueryRequest {
79 header: request.header,
80 region_id: request.region_id.as_u64(),
81 plan,
82 };
83
84 let ticket = Ticket {
85 ticket: request.encode_to_vec().into(),
86 };
87 self.do_get_inner(ticket)
88 .await
89 .map_err(BoxedError::new)
90 .context(meta_error::ExternalSnafu)
91 }
92}
93
impl RegionRequester {
    /// Creates a requester over `client`.
    ///
    /// `send_compression` and `accept_compression` are forwarded to
    /// `Client::make_flight_client` each time [`RegionRequester::do_get_inner`]
    /// creates a Flight client.
    pub fn new(client: Client, send_compression: bool, accept_compression: bool) -> Self {
        Self {
            client,
            send_compression,
            accept_compression,
        }
    }

    /// Issues an Arrow Flight `do_get` with `ticket` and adapts the response into a
    /// [`SendableRecordBatchStream`].
    ///
    /// The first Flight message is read eagerly and must be a schema. The returned
    /// stream decodes the remaining messages lazily:
    /// - each `RecordBatch` is yielded as a record batch, but only after the next
    ///   message has been peeked;
    /// - a `Metrics` message directly after a batch is parsed as JSON and stored in the
    ///   stream's metrics slot (a parse failure stores `None`); a `Metrics` message in
    ///   any other position is stored the same way and then ends the stream;
    /// - any other message, or a transport/decode error, yields a single error and ends
    ///   the stream. If that happens while peeking after a batch, that batch is not
    ///   yielded.
    ///
    /// # Errors
    ///
    /// Returns an error if the Flight client cannot be created, the `do_get` call fails,
    /// the response is empty, the first message is not a schema, or the schema cannot
    /// be converted.
    pub async fn do_get_inner(&self, ticket: Ticket) -> Result<SendableRecordBatchStream> {
        let mut flight_client = self
            .client
            .make_flight_client(self.send_compression, self.accept_compression)?;
        let response = flight_client
            .mut_inner()
            .do_get(ticket)
            .await
            .or_else(|e| {
                // Log the failure with the peer address and tonic code before wrapping it
                // into a `FlightGet` error.
                let tonic_code = e.code();
                let e: error::Error = e.into();
                error!(
                    e; "Failed to do Flight get, addr: {}, code: {}",
                    flight_client.addr(),
                    tonic_code
                );
                Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
                    addr: flight_client.addr().to_string(),
                    tonic_code,
                })
            })?;

        let flight_data_stream = response.into_inner();
        let mut decoder = FlightDecoder::default();

        // Decode every raw FlightData frame; a frame that decodes to no message is
        // reported as an illegal flight message ("none message").
        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
            flight_data
                .map_err(Error::from)
                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
                .context(IllegalFlightMessagesSnafu {
                    reason: "none message",
                })
        });

        // The schema must arrive first so the returned stream can expose it up front.
        let Some(first_flight_message) = flight_message_stream.next().await else {
            return IllegalFlightMessagesSnafu {
                reason: "Expect the response not to be empty",
            }
            .fail();
        };
        let FlightMessage::Schema(schema) = first_flight_message? else {
            return IllegalFlightMessagesSnafu {
                reason: "Expect schema to be the first flight message",
            }
            .fail();
        };

        // Shared metrics slot: the stream below writes into it, and the returned
        // wrapper holds the other handle.
        let metrics = Arc::new(ArcSwapOption::from(None));
        let metrics_ref = metrics.clone();

        // Captured now, while `do_get_inner` runs, so that polling the stream later is
        // traced under this caller's context.
        let tracing_context = TracingContext::from_current_span();

        let schema = Arc::new(
            datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
        );
        let schema_cloned = schema.clone();
        let stream = Box::pin(stream!({
            let _span = tracing_context.attach(common_telemetry::tracing::info_span!(
                "poll_flight_data_stream"
            ));

            // A RecordBatch that was read while peeking after the previous batch; it is
            // consumed before reading from the Flight stream again.
            let mut buffered_message: Option<FlightMessage> = None;
            let mut stream_ended = false;

            while !stream_ended {
                let flight_message_item = if let Some(msg) = buffered_message.take() {
                    Some(Ok(msg))
                } else {
                    flight_message_stream.next().await
                };

                let flight_message = match flight_message_item {
                    Some(Ok(message)) => message,
                    Some(Err(e)) => {
                        yield Err(BoxedError::new(e)).context(ExternalSnafu);
                        break;
                    }
                    None => break,
                };

                match flight_message {
                    FlightMessage::RecordBatch(record_batch) => {
                        let result_to_yield =
                            RecordBatch::from_df_record_batch(schema_cloned.clone(), record_batch);

                        // Peek at the following message before yielding this batch:
                        // record Metrics, buffer another RecordBatch, and treat
                        // anything else as a protocol error.
                        if let Some(next_flight_message_result) = flight_message_stream.next().await
                        {
                            match next_flight_message_result {
                                Ok(FlightMessage::Metrics(s)) => {
                                    // Malformed JSON yields `None`, which replaces any
                                    // previously stored metrics.
                                    let m = serde_json::from_str(&s).ok().map(Arc::new);
                                    metrics_ref.swap(m);
                                }
                                Ok(FlightMessage::RecordBatch(rb)) => {
                                    buffered_message = Some(FlightMessage::RecordBatch(rb));
                                }
                                Ok(_) => {
                                    // NOTE: the current batch is not yielded on this path.
                                    yield IllegalFlightMessagesSnafu {
                                        reason: "A RecordBatch message can only be succeeded by a Metrics message or another RecordBatch message"
                                    }
                                    .fail()
                                    .map_err(BoxedError::new)
                                    .context(ExternalSnafu);
                                    break;
                                }
                                Err(e) => {
                                    // NOTE: the current batch is not yielded on this path.
                                    yield Err(BoxedError::new(e)).context(ExternalSnafu);
                                    break;
                                }
                            }
                        } else {
                            // Underlying Flight stream is exhausted; emit this last batch
                            // and stop.
                            stream_ended = true;
                        }

                        yield Ok(result_to_yield);
                    }
                    FlightMessage::Metrics(s) => {
                        // Metrics not directly following a batch: record them and end
                        // the stream.
                        let m = serde_json::from_str(&s).ok().map(Arc::new);
                        metrics_ref.swap(m);
                        break;
                    }
                    _ => {
                        yield IllegalFlightMessagesSnafu {
                            reason: "A Schema message must be succeeded exclusively by a set of RecordBatch messages"
                        }
                        .fail()
                        .map_err(BoxedError::new)
                        .context(ExternalSnafu);
                        break;
                    }
                }
            }
        }));
        let record_batch_stream = RecordBatchStreamWrapper {
            schema,
            stream,
            output_ordering: None,
            metrics,
            span: Span::current(),
        };
        Ok(Box::pin(record_batch_stream))
    }

    /// Sends `request` through the raw region gRPC client and converts the reply into a
    /// [`RegionResponse`].
    ///
    /// The call's duration is recorded in `METRIC_REGION_REQUEST_GRPC`, labelled with
    /// the string obtained from the request body's `as_ref()`.
    ///
    /// # Errors
    ///
    /// - `MissingField` if the request has no body.
    /// - `RegionServer` (with the peer address and tonic code) if the gRPC call fails.
    /// - Any error from [`check_response_header`] if the reply header is missing or
    ///   reports a failure.
    async fn handle_inner(&self, request: RegionRequest) -> Result<RegionResponse> {
        let request_type = request
            .body
            .as_ref()
            .with_context(|| MissingFieldSnafu { field: "body" })?
            .as_ref()
            .to_string();
        // Kept alive until the end of the function so the whole round trip is timed.
        let _timer = metrics::METRIC_REGION_REQUEST_GRPC
            .with_label_values(&[request_type.as_str()])
            .start_timer();

        let (addr, mut client) = self.client.raw_region_client()?;

        let response = client
            .handle(request)
            .await
            .map_err(|e| {
                let code = e.code();
                error::Error::RegionServer {
                    addr,
                    code,
                    source: BoxedError::new(error::Error::from(e)),
                    location: location!(),
                }
            })?
            .into_inner();

        check_response_header(&response.header)?;

        Ok(RegionResponse::from_region_response(response))
    }

    /// Public entry point for region requests; delegates to `handle_inner` and returns
    /// this crate's error type unchanged.
    pub async fn handle(&self, request: RegionRequest) -> Result<RegionResponse> {
        self.handle_inner(request).await
    }
}
288
289pub fn check_response_header(header: &Option<ResponseHeader>) -> Result<()> {
290 let status = header
291 .as_ref()
292 .and_then(|header| header.status.as_ref())
293 .context(IllegalDatabaseResponseSnafu {
294 err_msg: "either response header or status is missing",
295 })?;
296
297 if StatusCode::is_success(status.status_code) {
298 Ok(())
299 } else {
300 let code =
301 StatusCode::from_u32(status.status_code).context(IllegalDatabaseResponseSnafu {
302 err_msg: format!("unknown server status: {:?}", status),
303 })?;
304 ServerSnafu {
305 code,
306 msg: status.err_msg.clone(),
307 }
308 .fail()
309 }
310}
311
#[cfg(test)]
mod test {
    use api::v1::Status as PbStatus;

    use super::*;
    use crate::Error::{IllegalDatabaseResponse, Server};

    /// Builds a response header whose status has the given raw code and message.
    fn header_with_status(status_code: u32, err_msg: &str) -> Option<ResponseHeader> {
        Some(ResponseHeader {
            status: Some(PbStatus {
                status_code,
                err_msg: err_msg.to_string(),
            }),
        })
    }

    #[test]
    fn test_check_response_header() {
        // Both a missing header and a header without a status are malformed.
        for malformed in [None, Some(ResponseHeader { status: None })] {
            let err = check_response_header(&malformed).unwrap_err();
            assert!(matches!(err, IllegalDatabaseResponse { .. }));
        }

        // A success status passes validation.
        let ok = check_response_header(&header_with_status(StatusCode::Success as u32, ""));
        assert!(ok.is_ok());

        // A status code that maps to no known `StatusCode` is malformed too.
        let err = check_response_header(&header_with_status(u32::MAX, "")).unwrap_err();
        assert!(matches!(err, IllegalDatabaseResponse { .. }));

        // A known failure code is surfaced as a server error with the original message.
        let failed =
            check_response_header(&header_with_status(StatusCode::Internal as u32, "blabla"));
        let Server { code, msg, .. } = failed.unwrap_err() else {
            unreachable!()
        };
        assert_eq!(code, StatusCode::Internal);
        assert_eq!(msg, "blabla");
    }
}