1use std::sync::Arc;
16
17use api::region::RegionResponse;
18use api::v1::ResponseHeader;
19use api::v1::region::RegionRequest;
20use arc_swap::ArcSwapOption;
21use arrow_flight::Ticket;
22use async_stream::stream;
23use async_trait::async_trait;
24use common_error::ext::BoxedError;
25use common_error::status_code::StatusCode;
26use common_grpc::flight::{FlightDecoder, FlightMessage};
27use common_meta::error::{self as meta_error, Result as MetaResult};
28use common_meta::node_manager::Datanode;
29use common_query::request::QueryRequest;
30use common_recordbatch::error::ExternalSnafu;
31use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper, SendableRecordBatchStream};
32use common_telemetry::error;
33use common_telemetry::tracing_context::TracingContext;
34use prost::Message;
35use query::query_engine::DefaultSerializer;
36use snafu::{OptionExt, ResultExt, location};
37use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
38use tokio_stream::StreamExt;
39
40use crate::error::{
41 self, ConvertFlightDataSnafu, FlightGetSnafu, IllegalDatabaseResponseSnafu,
42 IllegalFlightMessagesSnafu, MissingFieldSnafu, Result, ServerSnafu,
43};
44use crate::{Client, Error, metrics};
45
46#[derive(Debug)]
47pub struct RegionRequester {
48 client: Client,
49 send_compression: bool,
50 accept_compression: bool,
51}
52
53#[async_trait]
54impl Datanode for RegionRequester {
55 async fn handle(&self, request: RegionRequest) -> MetaResult<RegionResponse> {
56 self.handle_inner(request).await.map_err(|err| {
57 if err.should_retry() {
58 meta_error::Error::RetryLater {
59 source: BoxedError::new(err),
60 clean_poisons: false,
61 }
62 } else {
63 meta_error::Error::External {
64 source: BoxedError::new(err),
65 location: location!(),
66 }
67 }
68 })
69 }
70
71 async fn handle_query(&self, request: QueryRequest) -> MetaResult<SendableRecordBatchStream> {
72 let plan = DFLogicalSubstraitConvertor
73 .encode(&request.plan, DefaultSerializer)
74 .map_err(BoxedError::new)
75 .context(meta_error::ExternalSnafu)?
76 .to_vec();
77 let request = api::v1::region::QueryRequest {
78 header: request.header,
79 region_id: request.region_id.as_u64(),
80 plan,
81 };
82
83 let ticket = Ticket {
84 ticket: request.encode_to_vec().into(),
85 };
86 self.do_get_inner(ticket)
87 .await
88 .map_err(BoxedError::new)
89 .context(meta_error::ExternalSnafu)
90 }
91}
92
93impl RegionRequester {
94 pub fn new(client: Client, send_compression: bool, accept_compression: bool) -> Self {
95 Self {
96 client,
97 send_compression,
98 accept_compression,
99 }
100 }
101
102 pub async fn do_get_inner(&self, ticket: Ticket) -> Result<SendableRecordBatchStream> {
103 let mut flight_client = self
104 .client
105 .make_flight_client(self.send_compression, self.accept_compression)?;
106 let response = flight_client
107 .mut_inner()
108 .do_get(ticket)
109 .await
110 .or_else(|e| {
111 let tonic_code = e.code();
112 let e: error::Error = e.into();
113 error!(
114 e; "Failed to do Flight get, addr: {}, code: {}",
115 flight_client.addr(),
116 tonic_code
117 );
118 Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
119 addr: flight_client.addr().to_string(),
120 tonic_code,
121 })
122 })?;
123
124 let flight_data_stream = response.into_inner();
125 let mut decoder = FlightDecoder::default();
126
127 let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
128 flight_data
129 .map_err(Error::from)
130 .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
131 .context(IllegalFlightMessagesSnafu {
132 reason: "none message",
133 })
134 });
135
136 let Some(first_flight_message) = flight_message_stream.next().await else {
137 return IllegalFlightMessagesSnafu {
138 reason: "Expect the response not to be empty",
139 }
140 .fail();
141 };
142 let FlightMessage::Schema(schema) = first_flight_message? else {
143 return IllegalFlightMessagesSnafu {
144 reason: "Expect schema to be the first flight message",
145 }
146 .fail();
147 };
148
149 let metrics = Arc::new(ArcSwapOption::from(None));
150 let metrics_ref = metrics.clone();
151
152 let tracing_context = TracingContext::from_current_span();
153
154 let schema = Arc::new(
155 datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
156 );
157 let schema_cloned = schema.clone();
158 let stream = Box::pin(stream!({
159 let _span = tracing_context.attach(common_telemetry::tracing::info_span!(
160 "poll_flight_data_stream"
161 ));
162
163 let mut buffered_message: Option<FlightMessage> = None;
164 let mut stream_ended = false;
165
166 while !stream_ended {
167 let flight_message_item = if let Some(msg) = buffered_message.take() {
169 Some(Ok(msg))
170 } else {
171 flight_message_stream.next().await
172 };
173
174 let flight_message = match flight_message_item {
175 Some(Ok(message)) => message,
176 Some(Err(e)) => {
177 yield Err(BoxedError::new(e)).context(ExternalSnafu);
178 break;
179 }
180 None => break,
181 };
182
183 match flight_message {
184 FlightMessage::RecordBatch(record_batch) => {
185 let result_to_yield =
186 RecordBatch::from_df_record_batch(schema_cloned.clone(), record_batch);
187
188 if let Some(next_flight_message_result) = flight_message_stream.next().await
190 {
191 match next_flight_message_result {
192 Ok(FlightMessage::Metrics(s)) => {
193 let m = serde_json::from_str(&s).ok().map(Arc::new);
194 metrics_ref.swap(m);
195 }
196 Ok(FlightMessage::RecordBatch(rb)) => {
197 buffered_message = Some(FlightMessage::RecordBatch(rb));
200 }
201 Ok(_) => {
202 yield IllegalFlightMessagesSnafu {
203 reason: "A RecordBatch message can only be succeeded by a Metrics message or another RecordBatch message"
204 }
205 .fail()
206 .map_err(BoxedError::new)
207 .context(ExternalSnafu);
208 break;
209 }
210 Err(e) => {
211 yield Err(BoxedError::new(e)).context(ExternalSnafu);
212 break;
213 }
214 }
215 } else {
216 stream_ended = true;
218 }
219
220 yield Ok(result_to_yield);
221 }
222 FlightMessage::Metrics(s) => {
223 let m = serde_json::from_str(&s).ok().map(Arc::new);
225 metrics_ref.swap(m);
226 break;
227 }
228 _ => {
229 yield IllegalFlightMessagesSnafu {
230 reason: "A Schema message must be succeeded exclusively by a set of RecordBatch messages"
231 }
232 .fail()
233 .map_err(BoxedError::new)
234 .context(ExternalSnafu);
235 break;
236 }
237 }
238 }
239 }));
240 let record_batch_stream = RecordBatchStreamWrapper {
241 schema,
242 stream,
243 output_ordering: None,
244 metrics,
245 };
246 Ok(Box::pin(record_batch_stream))
247 }
248
249 async fn handle_inner(&self, request: RegionRequest) -> Result<RegionResponse> {
250 let request_type = request
251 .body
252 .as_ref()
253 .with_context(|| MissingFieldSnafu { field: "body" })?
254 .as_ref()
255 .to_string();
256 let _timer = metrics::METRIC_REGION_REQUEST_GRPC
257 .with_label_values(&[request_type.as_str()])
258 .start_timer();
259
260 let (addr, mut client) = self.client.raw_region_client()?;
261
262 let response = client
263 .handle(request)
264 .await
265 .map_err(|e| {
266 let code = e.code();
267 error::Error::RegionServer {
269 addr,
270 code,
271 source: BoxedError::new(error::Error::from(e)),
272 location: location!(),
273 }
274 })?
275 .into_inner();
276
277 check_response_header(&response.header)?;
278
279 Ok(RegionResponse::from_region_response(response))
280 }
281
282 pub async fn handle(&self, request: RegionRequest) -> Result<RegionResponse> {
283 self.handle_inner(request).await
284 }
285}
286
287pub fn check_response_header(header: &Option<ResponseHeader>) -> Result<()> {
288 let status = header
289 .as_ref()
290 .and_then(|header| header.status.as_ref())
291 .context(IllegalDatabaseResponseSnafu {
292 err_msg: "either response header or status is missing",
293 })?;
294
295 if StatusCode::is_success(status.status_code) {
296 Ok(())
297 } else {
298 let code =
299 StatusCode::from_u32(status.status_code).context(IllegalDatabaseResponseSnafu {
300 err_msg: format!("unknown server status: {:?}", status),
301 })?;
302 ServerSnafu {
303 code,
304 msg: status.err_msg.clone(),
305 }
306 .fail()
307 }
308}
309
#[cfg(test)]
mod test {
    use api::v1::Status as PbStatus;

    use super::*;
    use crate::Error::{IllegalDatabaseResponse, Server};

    /// Builds a response header whose status has the given code and message.
    fn header_with_status(status_code: u32, err_msg: &str) -> Option<ResponseHeader> {
        Some(ResponseHeader {
            status: Some(PbStatus {
                status_code,
                err_msg: err_msg.to_string(),
            }),
        })
    }

    #[test]
    fn test_check_response_header() {
        // A missing header and a header without status are both illegal responses.
        for header in [None, Some(ResponseHeader { status: None })] {
            let err = check_response_header(&header).unwrap_err();
            assert!(matches!(err, IllegalDatabaseResponse { .. }));
        }

        // A success status passes the check.
        let ok_header = header_with_status(StatusCode::Success as u32, "");
        assert!(check_response_header(&ok_header).is_ok());

        // An unrecognised status code is an illegal response.
        let unknown_header = header_with_status(u32::MAX, "");
        let err = check_response_header(&unknown_header).unwrap_err();
        assert!(matches!(err, IllegalDatabaseResponse { .. }));

        // A known failure code surfaces as a server error carrying its message.
        let failed_header = header_with_status(StatusCode::Internal as u32, "blabla");
        let err = check_response_header(&failed_header).unwrap_err();
        let Server { code, msg, .. } = err else {
            unreachable!()
        };
        assert_eq!(code, StatusCode::Internal);
        assert_eq!(msg, "blabla");
    }
}