1use std::sync::Arc;
16
17use api::region::RegionResponse;
18use api::v1::region::RegionRequest;
19use api::v1::ResponseHeader;
20use arc_swap::ArcSwapOption;
21use arrow_flight::Ticket;
22use async_stream::stream;
23use async_trait::async_trait;
24use common_error::ext::{BoxedError, ErrorExt};
25use common_error::status_code::StatusCode;
26use common_grpc::flight::{FlightDecoder, FlightMessage};
27use common_meta::error::{self as meta_error, Result as MetaResult};
28use common_meta::node_manager::Datanode;
29use common_query::request::QueryRequest;
30use common_recordbatch::error::ExternalSnafu;
31use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper, SendableRecordBatchStream};
32use common_telemetry::error;
33use common_telemetry::tracing_context::TracingContext;
34use prost::Message;
35use query::query_engine::DefaultSerializer;
36use snafu::{location, OptionExt, ResultExt};
37use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
38use tokio_stream::StreamExt;
39
40use crate::error::{
41 self, ConvertFlightDataSnafu, FlightGetSnafu, IllegalDatabaseResponseSnafu,
42 IllegalFlightMessagesSnafu, MissingFieldSnafu, Result, ServerSnafu,
43};
44use crate::{metrics, Client, Error};
45
46#[derive(Debug)]
47pub struct RegionRequester {
48 client: Client,
49 send_compression: bool,
50 accept_compression: bool,
51}
52
53#[async_trait]
54impl Datanode for RegionRequester {
55 async fn handle(&self, request: RegionRequest) -> MetaResult<RegionResponse> {
56 self.handle_inner(request).await.map_err(|err| {
57 if err.should_retry() {
58 meta_error::Error::RetryLater {
59 source: BoxedError::new(err),
60 clean_poisons: false,
61 }
62 } else {
63 meta_error::Error::External {
64 source: BoxedError::new(err),
65 location: location!(),
66 }
67 }
68 })
69 }
70
71 async fn handle_query(&self, request: QueryRequest) -> MetaResult<SendableRecordBatchStream> {
72 let plan = DFLogicalSubstraitConvertor
73 .encode(&request.plan, DefaultSerializer)
74 .map_err(BoxedError::new)
75 .context(meta_error::ExternalSnafu)?
76 .to_vec();
77 let request = api::v1::region::QueryRequest {
78 header: request.header,
79 region_id: request.region_id.as_u64(),
80 plan,
81 };
82
83 let ticket = Ticket {
84 ticket: request.encode_to_vec().into(),
85 };
86 self.do_get_inner(ticket)
87 .await
88 .map_err(BoxedError::new)
89 .context(meta_error::ExternalSnafu)
90 }
91}
92
93impl RegionRequester {
94 pub fn new(client: Client, send_compression: bool, accept_compression: bool) -> Self {
95 Self {
96 client,
97 send_compression,
98 accept_compression,
99 }
100 }
101
102 pub async fn do_get_inner(&self, ticket: Ticket) -> Result<SendableRecordBatchStream> {
103 let mut flight_client = self
104 .client
105 .make_flight_client(self.send_compression, self.accept_compression)?;
106 let response = flight_client
107 .mut_inner()
108 .do_get(ticket)
109 .await
110 .map_err(|e| {
111 let tonic_code = e.code();
112 let e: error::Error = e.into();
113 let code = e.status_code();
114 let msg = e.to_string();
115 let error = ServerSnafu { code, msg }
116 .fail::<()>()
117 .map_err(BoxedError::new)
118 .with_context(|_| FlightGetSnafu {
119 tonic_code,
120 addr: flight_client.addr().to_string(),
121 })
122 .unwrap_err();
123 error!(
124 e; "Failed to do Flight get, addr: {}, code: {}",
125 flight_client.addr(),
126 tonic_code
127 );
128 error
129 })?;
130
131 let flight_data_stream = response.into_inner();
132 let mut decoder = FlightDecoder::default();
133
134 let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
135 flight_data
136 .map_err(Error::from)
137 .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))
138 });
139
140 let Some(first_flight_message) = flight_message_stream.next().await else {
141 return IllegalFlightMessagesSnafu {
142 reason: "Expect the response not to be empty",
143 }
144 .fail();
145 };
146 let FlightMessage::Schema(schema) = first_flight_message? else {
147 return IllegalFlightMessagesSnafu {
148 reason: "Expect schema to be the first flight message",
149 }
150 .fail();
151 };
152
153 let metrics = Arc::new(ArcSwapOption::from(None));
154 let metrics_ref = metrics.clone();
155
156 let tracing_context = TracingContext::from_current_span();
157
158 let schema = Arc::new(
159 datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
160 );
161 let schema_cloned = schema.clone();
162 let stream = Box::pin(stream!({
163 let _span = tracing_context.attach(common_telemetry::tracing::info_span!(
164 "poll_flight_data_stream"
165 ));
166 while let Some(flight_message) = flight_message_stream.next().await {
167 let flight_message = flight_message
168 .map_err(BoxedError::new)
169 .context(ExternalSnafu)?;
170
171 match flight_message {
172 FlightMessage::RecordBatch(record_batch) => {
173 yield RecordBatch::try_from_df_record_batch(
174 schema_cloned.clone(),
175 record_batch,
176 )
177 }
178 FlightMessage::Metrics(s) => {
179 let m = serde_json::from_str(&s).ok().map(Arc::new);
180 metrics_ref.swap(m);
181 break;
182 }
183 _ => {
184 yield IllegalFlightMessagesSnafu {
185 reason: "A Schema message must be succeeded exclusively by a set of RecordBatch messages"
186 }
187 .fail()
188 .map_err(BoxedError::new)
189 .context(ExternalSnafu);
190 break;
191 }
192 }
193 }
194 }));
195 let record_batch_stream = RecordBatchStreamWrapper {
196 schema,
197 stream,
198 output_ordering: None,
199 metrics,
200 };
201 Ok(Box::pin(record_batch_stream))
202 }
203
204 async fn handle_inner(&self, request: RegionRequest) -> Result<RegionResponse> {
205 let request_type = request
206 .body
207 .as_ref()
208 .with_context(|| MissingFieldSnafu { field: "body" })?
209 .as_ref()
210 .to_string();
211 let _timer = metrics::METRIC_REGION_REQUEST_GRPC
212 .with_label_values(&[request_type.as_str()])
213 .start_timer();
214
215 let (addr, mut client) = self.client.raw_region_client()?;
216
217 let response = client
218 .handle(request)
219 .await
220 .map_err(|e| {
221 let code = e.code();
222 error::Error::RegionServer {
224 addr,
225 code,
226 source: BoxedError::new(error::Error::from(e)),
227 location: location!(),
228 }
229 })?
230 .into_inner();
231
232 check_response_header(&response.header)?;
233
234 Ok(RegionResponse::from_region_response(response))
235 }
236
237 pub async fn handle(&self, request: RegionRequest) -> Result<RegionResponse> {
238 self.handle_inner(request).await
239 }
240}
241
242pub fn check_response_header(header: &Option<ResponseHeader>) -> Result<()> {
243 let status = header
244 .as_ref()
245 .and_then(|header| header.status.as_ref())
246 .context(IllegalDatabaseResponseSnafu {
247 err_msg: "either response header or status is missing",
248 })?;
249
250 if StatusCode::is_success(status.status_code) {
251 Ok(())
252 } else {
253 let code =
254 StatusCode::from_u32(status.status_code).context(IllegalDatabaseResponseSnafu {
255 err_msg: format!("unknown server status: {:?}", status),
256 })?;
257 ServerSnafu {
258 code,
259 msg: status.err_msg.clone(),
260 }
261 .fail()
262 }
263}
264
#[cfg(test)]
mod test {
    use api::v1::Status as PbStatus;

    use super::*;
    use crate::Error::{IllegalDatabaseResponse, Server};

    /// Builds a response header whose status has the given code and message.
    fn header_with_status(status_code: u32, err_msg: &str) -> Option<ResponseHeader> {
        Some(ResponseHeader {
            status: Some(PbStatus {
                status_code,
                err_msg: err_msg.to_string(),
            }),
        })
    }

    #[test]
    fn test_check_response_header() {
        // Without a header, or with a header lacking a status, the response is malformed.
        for header in [None, Some(ResponseHeader { status: None })] {
            let err = check_response_header(&header).unwrap_err();
            assert!(matches!(err, IllegalDatabaseResponse { .. }));
        }

        // A success status passes the check.
        let ok = check_response_header(&header_with_status(StatusCode::Success as u32, ""));
        assert!(ok.is_ok());

        // A code that maps to no `StatusCode` is rejected as malformed.
        let err = check_response_header(&header_with_status(u32::MAX, "")).unwrap_err();
        assert!(matches!(err, IllegalDatabaseResponse { .. }));

        // A known failure code surfaces as `Server` with the original message.
        let err = check_response_header(&header_with_status(
            StatusCode::Internal as u32,
            "blabla",
        ))
        .unwrap_err();
        match err {
            Server { code, msg, .. } => {
                assert_eq!(code, StatusCode::Internal);
                assert_eq!(msg, "blabla");
            }
            _ => unreachable!(),
        }
    }
}