1use std::sync::Arc;
16
17use api::region::RegionResponse;
18use api::v1::region::RegionRequest;
19use api::v1::ResponseHeader;
20use arc_swap::ArcSwapOption;
21use arrow_flight::Ticket;
22use async_stream::stream;
23use async_trait::async_trait;
24use common_error::ext::{BoxedError, ErrorExt};
25use common_error::status_code::StatusCode;
26use common_grpc::flight::{FlightDecoder, FlightMessage};
27use common_meta::error::{self as meta_error, Result as MetaResult};
28use common_meta::node_manager::Datanode;
29use common_query::request::QueryRequest;
30use common_recordbatch::error::ExternalSnafu;
31use common_recordbatch::{RecordBatchStreamWrapper, SendableRecordBatchStream};
32use common_telemetry::error;
33use common_telemetry::tracing_context::TracingContext;
34use prost::Message;
35use query::query_engine::DefaultSerializer;
36use snafu::{location, OptionExt, ResultExt};
37use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
38use tokio_stream::StreamExt;
39
40use crate::error::{
41 self, ConvertFlightDataSnafu, FlightGetSnafu, IllegalDatabaseResponseSnafu,
42 IllegalFlightMessagesSnafu, MissingFieldSnafu, Result, ServerSnafu,
43};
44use crate::{metrics, Client, Error};
45
46#[derive(Debug)]
47pub struct RegionRequester {
48 client: Client,
49}
50
51#[async_trait]
52impl Datanode for RegionRequester {
53 async fn handle(&self, request: RegionRequest) -> MetaResult<RegionResponse> {
54 self.handle_inner(request).await.map_err(|err| {
55 if err.should_retry() {
56 meta_error::Error::RetryLater {
57 source: BoxedError::new(err),
58 }
59 } else {
60 meta_error::Error::External {
61 source: BoxedError::new(err),
62 location: location!(),
63 }
64 }
65 })
66 }
67
68 async fn handle_query(&self, request: QueryRequest) -> MetaResult<SendableRecordBatchStream> {
69 let plan = DFLogicalSubstraitConvertor
70 .encode(&request.plan, DefaultSerializer)
71 .map_err(BoxedError::new)
72 .context(meta_error::ExternalSnafu)?
73 .to_vec();
74 let request = api::v1::region::QueryRequest {
75 header: request.header,
76 region_id: request.region_id.as_u64(),
77 plan,
78 };
79
80 let ticket = Ticket {
81 ticket: request.encode_to_vec().into(),
82 };
83 self.do_get_inner(ticket)
84 .await
85 .map_err(BoxedError::new)
86 .context(meta_error::ExternalSnafu)
87 }
88}
89
90impl RegionRequester {
91 pub fn new(client: Client) -> Self {
92 Self { client }
93 }
94
95 pub async fn do_get_inner(&self, ticket: Ticket) -> Result<SendableRecordBatchStream> {
96 let mut flight_client = self.client.make_flight_client()?;
97 let response = flight_client
98 .mut_inner()
99 .do_get(ticket)
100 .await
101 .map_err(|e| {
102 let tonic_code = e.code();
103 let e: error::Error = e.into();
104 let code = e.status_code();
105 let msg = e.to_string();
106 let error = ServerSnafu { code, msg }
107 .fail::<()>()
108 .map_err(BoxedError::new)
109 .with_context(|_| FlightGetSnafu {
110 tonic_code,
111 addr: flight_client.addr().to_string(),
112 })
113 .unwrap_err();
114 error!(
115 e; "Failed to do Flight get, addr: {}, code: {}",
116 flight_client.addr(),
117 tonic_code
118 );
119 error
120 })?;
121
122 let flight_data_stream = response.into_inner();
123 let mut decoder = FlightDecoder::default();
124
125 let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
126 flight_data
127 .map_err(Error::from)
128 .and_then(|data| decoder.try_decode(data).context(ConvertFlightDataSnafu))
129 });
130
131 let Some(first_flight_message) = flight_message_stream.next().await else {
132 return IllegalFlightMessagesSnafu {
133 reason: "Expect the response not to be empty",
134 }
135 .fail();
136 };
137 let FlightMessage::Schema(schema) = first_flight_message? else {
138 return IllegalFlightMessagesSnafu {
139 reason: "Expect schema to be the first flight message",
140 }
141 .fail();
142 };
143
144 let metrics = Arc::new(ArcSwapOption::from(None));
145 let metrics_ref = metrics.clone();
146
147 let tracing_context = TracingContext::from_current_span();
148
149 let stream = Box::pin(stream!({
150 let _span = tracing_context.attach(common_telemetry::tracing::info_span!(
151 "poll_flight_data_stream"
152 ));
153 while let Some(flight_message) = flight_message_stream.next().await {
154 let flight_message = flight_message
155 .map_err(BoxedError::new)
156 .context(ExternalSnafu)?;
157
158 match flight_message {
159 FlightMessage::Recordbatch(record_batch) => yield Ok(record_batch),
160 FlightMessage::Metrics(s) => {
161 let m = serde_json::from_str(&s).ok().map(Arc::new);
162 metrics_ref.swap(m);
163 break;
164 }
165 _ => {
166 yield IllegalFlightMessagesSnafu {
167 reason: "A Schema message must be succeeded exclusively by a set of RecordBatch messages"
168 }
169 .fail()
170 .map_err(BoxedError::new)
171 .context(ExternalSnafu);
172 break;
173 }
174 }
175 }
176 }));
177 let record_batch_stream = RecordBatchStreamWrapper {
178 schema,
179 stream,
180 output_ordering: None,
181 metrics,
182 };
183 Ok(Box::pin(record_batch_stream))
184 }
185
186 async fn handle_inner(&self, request: RegionRequest) -> Result<RegionResponse> {
187 let request_type = request
188 .body
189 .as_ref()
190 .with_context(|| MissingFieldSnafu { field: "body" })?
191 .as_ref()
192 .to_string();
193 let _timer = metrics::METRIC_REGION_REQUEST_GRPC
194 .with_label_values(&[request_type.as_str()])
195 .start_timer();
196
197 let (addr, mut client) = self.client.raw_region_client()?;
198
199 let response = client
200 .handle(request)
201 .await
202 .map_err(|e| {
203 let code = e.code();
204 error::Error::RegionServer {
206 addr,
207 code,
208 source: BoxedError::new(error::Error::from(e)),
209 location: location!(),
210 }
211 })?
212 .into_inner();
213
214 check_response_header(&response.header)?;
215
216 Ok(RegionResponse::from_region_response(response))
217 }
218
219 pub async fn handle(&self, request: RegionRequest) -> Result<RegionResponse> {
220 self.handle_inner(request).await
221 }
222}
223
224pub fn check_response_header(header: &Option<ResponseHeader>) -> Result<()> {
225 let status = header
226 .as_ref()
227 .and_then(|header| header.status.as_ref())
228 .context(IllegalDatabaseResponseSnafu {
229 err_msg: "either response header or status is missing",
230 })?;
231
232 if StatusCode::is_success(status.status_code) {
233 Ok(())
234 } else {
235 let code =
236 StatusCode::from_u32(status.status_code).context(IllegalDatabaseResponseSnafu {
237 err_msg: format!("unknown server status: {:?}", status),
238 })?;
239 ServerSnafu {
240 code,
241 msg: status.err_msg.clone(),
242 }
243 .fail()
244 }
245}
246
#[cfg(test)]
mod test {
    use api::v1::Status as PbStatus;

    use super::*;
    use crate::Error::{IllegalDatabaseResponse, Server};

    /// Builds a response header whose status has the given code and message.
    fn header_with_status(status_code: u32, err_msg: &str) -> Option<ResponseHeader> {
        Some(ResponseHeader {
            status: Some(PbStatus {
                status_code,
                err_msg: err_msg.to_string(),
            }),
        })
    }

    #[test]
    fn test_check_response_header() {
        // Both a missing header and a header without a status are malformed.
        for malformed in [None, Some(ResponseHeader { status: None })] {
            let err = check_response_header(&malformed).unwrap_err();
            assert!(matches!(err, IllegalDatabaseResponse { .. }));
        }

        // A success status is accepted.
        let success = header_with_status(StatusCode::Success as u32, "");
        assert!(check_response_header(&success).is_ok());

        // An unrecognised status code is reported as an illegal response.
        let unknown = header_with_status(u32::MAX, "");
        let err = check_response_header(&unknown).unwrap_err();
        assert!(matches!(err, IllegalDatabaseResponse { .. }));

        // A known failure code surfaces as a server error with the original message.
        let failed = header_with_status(StatusCode::Internal as u32, "blabla");
        let Server { code, msg, .. } = check_response_header(&failed).unwrap_err() else {
            unreachable!()
        };
        assert_eq!(code, StatusCode::Internal);
        assert_eq!(msg, "blabla");
    }
}