client/
region.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use api::region::RegionResponse;
18use api::v1::region::RegionRequest;
19use api::v1::ResponseHeader;
20use arc_swap::ArcSwapOption;
21use arrow_flight::Ticket;
22use async_stream::stream;
23use async_trait::async_trait;
24use common_error::ext::{BoxedError, ErrorExt};
25use common_error::status_code::StatusCode;
26use common_grpc::flight::{FlightDecoder, FlightMessage};
27use common_meta::error::{self as meta_error, Result as MetaResult};
28use common_meta::node_manager::Datanode;
29use common_query::request::QueryRequest;
30use common_recordbatch::error::ExternalSnafu;
31use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper, SendableRecordBatchStream};
32use common_telemetry::error;
33use common_telemetry::tracing_context::TracingContext;
34use prost::Message;
35use query::query_engine::DefaultSerializer;
36use snafu::{location, OptionExt, ResultExt};
37use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
38use tokio_stream::StreamExt;
39
40use crate::error::{
41    self, ConvertFlightDataSnafu, FlightGetSnafu, IllegalDatabaseResponseSnafu,
42    IllegalFlightMessagesSnafu, MissingFieldSnafu, Result, ServerSnafu,
43};
44use crate::{metrics, Client, Error};
45
/// Sends region requests and Flight queries to a datanode through a [`Client`].
#[derive(Debug)]
pub struct RegionRequester {
    /// Connection used both for the raw region gRPC client and for building Flight clients.
    client: Client,
    /// Forwarded as the send-compression flag to `Client::make_flight_client`.
    send_compression: bool,
    /// Forwarded as the accept-compression flag to `Client::make_flight_client`.
    accept_compression: bool,
}
52
53#[async_trait]
54impl Datanode for RegionRequester {
55    async fn handle(&self, request: RegionRequest) -> MetaResult<RegionResponse> {
56        self.handle_inner(request).await.map_err(|err| {
57            if err.should_retry() {
58                meta_error::Error::RetryLater {
59                    source: BoxedError::new(err),
60                    clean_poisons: false,
61                }
62            } else {
63                meta_error::Error::External {
64                    source: BoxedError::new(err),
65                    location: location!(),
66                }
67            }
68        })
69    }
70
71    async fn handle_query(&self, request: QueryRequest) -> MetaResult<SendableRecordBatchStream> {
72        let plan = DFLogicalSubstraitConvertor
73            .encode(&request.plan, DefaultSerializer)
74            .map_err(BoxedError::new)
75            .context(meta_error::ExternalSnafu)?
76            .to_vec();
77        let request = api::v1::region::QueryRequest {
78            header: request.header,
79            region_id: request.region_id.as_u64(),
80            plan,
81        };
82
83        let ticket = Ticket {
84            ticket: request.encode_to_vec().into(),
85        };
86        self.do_get_inner(ticket)
87            .await
88            .map_err(BoxedError::new)
89            .context(meta_error::ExternalSnafu)
90    }
91}
92
93impl RegionRequester {
94    pub fn new(client: Client, send_compression: bool, accept_compression: bool) -> Self {
95        Self {
96            client,
97            send_compression,
98            accept_compression,
99        }
100    }
101
102    pub async fn do_get_inner(&self, ticket: Ticket) -> Result<SendableRecordBatchStream> {
103        let mut flight_client = self
104            .client
105            .make_flight_client(self.send_compression, self.accept_compression)?;
106        let response = flight_client
107            .mut_inner()
108            .do_get(ticket)
109            .await
110            .map_err(|e| {
111                let tonic_code = e.code();
112                let e: error::Error = e.into();
113                let code = e.status_code();
114                let msg = e.to_string();
115                let error = ServerSnafu { code, msg }
116                    .fail::<()>()
117                    .map_err(BoxedError::new)
118                    .with_context(|_| FlightGetSnafu {
119                        tonic_code,
120                        addr: flight_client.addr().to_string(),
121                    })
122                    .unwrap_err();
123                error!(
124                    e; "Failed to do Flight get, addr: {}, code: {}",
125                    flight_client.addr(),
126                    tonic_code
127                );
128                error
129            })?;
130
131        let flight_data_stream = response.into_inner();
132        let mut decoder = FlightDecoder::default();
133
134        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
135            flight_data
136                .map_err(Error::from)
137                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))
138        });
139
140        let Some(first_flight_message) = flight_message_stream.next().await else {
141            return IllegalFlightMessagesSnafu {
142                reason: "Expect the response not to be empty",
143            }
144            .fail();
145        };
146        let FlightMessage::Schema(schema) = first_flight_message? else {
147            return IllegalFlightMessagesSnafu {
148                reason: "Expect schema to be the first flight message",
149            }
150            .fail();
151        };
152
153        let metrics = Arc::new(ArcSwapOption::from(None));
154        let metrics_ref = metrics.clone();
155
156        let tracing_context = TracingContext::from_current_span();
157
158        let schema = Arc::new(
159            datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
160        );
161        let schema_cloned = schema.clone();
162        let stream = Box::pin(stream!({
163            let _span = tracing_context.attach(common_telemetry::tracing::info_span!(
164                "poll_flight_data_stream"
165            ));
166            while let Some(flight_message) = flight_message_stream.next().await {
167                let flight_message = flight_message
168                    .map_err(BoxedError::new)
169                    .context(ExternalSnafu)?;
170
171                match flight_message {
172                    FlightMessage::RecordBatch(record_batch) => {
173                        yield RecordBatch::try_from_df_record_batch(
174                            schema_cloned.clone(),
175                            record_batch,
176                        )
177                    }
178                    FlightMessage::Metrics(s) => {
179                        let m = serde_json::from_str(&s).ok().map(Arc::new);
180                        metrics_ref.swap(m);
181                        break;
182                    }
183                    _ => {
184                        yield IllegalFlightMessagesSnafu {
185                            reason: "A Schema message must be succeeded exclusively by a set of RecordBatch messages"
186                        }
187                        .fail()
188                        .map_err(BoxedError::new)
189                        .context(ExternalSnafu);
190                        break;
191                    }
192                }
193            }
194        }));
195        let record_batch_stream = RecordBatchStreamWrapper {
196            schema,
197            stream,
198            output_ordering: None,
199            metrics,
200        };
201        Ok(Box::pin(record_batch_stream))
202    }
203
204    async fn handle_inner(&self, request: RegionRequest) -> Result<RegionResponse> {
205        let request_type = request
206            .body
207            .as_ref()
208            .with_context(|| MissingFieldSnafu { field: "body" })?
209            .as_ref()
210            .to_string();
211        let _timer = metrics::METRIC_REGION_REQUEST_GRPC
212            .with_label_values(&[request_type.as_str()])
213            .start_timer();
214
215        let (addr, mut client) = self.client.raw_region_client()?;
216
217        let response = client
218            .handle(request)
219            .await
220            .map_err(|e| {
221                let code = e.code();
222                // Uses `Error::RegionServer` instead of `Error::Server`
223                error::Error::RegionServer {
224                    addr,
225                    code,
226                    source: BoxedError::new(error::Error::from(e)),
227                    location: location!(),
228                }
229            })?
230            .into_inner();
231
232        check_response_header(&response.header)?;
233
234        Ok(RegionResponse::from_region_response(response))
235    }
236
237    pub async fn handle(&self, request: RegionRequest) -> Result<RegionResponse> {
238        self.handle_inner(request).await
239    }
240}
241
242pub fn check_response_header(header: &Option<ResponseHeader>) -> Result<()> {
243    let status = header
244        .as_ref()
245        .and_then(|header| header.status.as_ref())
246        .context(IllegalDatabaseResponseSnafu {
247            err_msg: "either response header or status is missing",
248        })?;
249
250    if StatusCode::is_success(status.status_code) {
251        Ok(())
252    } else {
253        let code =
254            StatusCode::from_u32(status.status_code).context(IllegalDatabaseResponseSnafu {
255                err_msg: format!("unknown server status: {:?}", status),
256            })?;
257        ServerSnafu {
258            code,
259            msg: status.err_msg.clone(),
260        }
261        .fail()
262    }
263}
264
#[cfg(test)]
mod test {
    use api::v1::Status as PbStatus;

    use super::*;
    use crate::Error::{IllegalDatabaseResponse, Server};

    /// Builds a response header holding a status with the given code and message.
    fn header_with_status(status_code: u32, err_msg: &str) -> Option<ResponseHeader> {
        Some(ResponseHeader {
            status: Some(PbStatus {
                status_code,
                err_msg: err_msg.to_string(),
            }),
        })
    }

    #[test]
    fn test_check_response_header() {
        // No header at all, or a header without a status, is an illegal response.
        for missing in [None, Some(ResponseHeader { status: None })] {
            let err = check_response_header(&missing).unwrap_err();
            assert!(matches!(err, IllegalDatabaseResponse { .. }));
        }

        // A success status passes.
        let success = header_with_status(StatusCode::Success as u32, "");
        assert!(check_response_header(&success).is_ok());

        // A code the client cannot decode is an illegal response.
        let unknown = header_with_status(u32::MAX, "");
        let err = check_response_header(&unknown).unwrap_err();
        assert!(matches!(err, IllegalDatabaseResponse { .. }));

        // Any other failure surfaces as a server error with the original code and message.
        let internal = header_with_status(StatusCode::Internal as u32, "blabla");
        match check_response_header(&internal).unwrap_err() {
            Server { code, msg, .. } => {
                assert_eq!(code, StatusCode::Internal);
                assert_eq!(msg, "blabla");
            }
            _ => unreachable!(),
        }
    }
}