client/
region.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use api::region::RegionResponse;
18use api::v1::region::RegionRequest;
19use api::v1::ResponseHeader;
20use arc_swap::ArcSwapOption;
21use arrow_flight::Ticket;
22use async_stream::stream;
23use async_trait::async_trait;
24use common_error::ext::{BoxedError, ErrorExt};
25use common_error::status_code::StatusCode;
26use common_grpc::flight::{FlightDecoder, FlightMessage};
27use common_meta::error::{self as meta_error, Result as MetaResult};
28use common_meta::node_manager::Datanode;
29use common_query::request::QueryRequest;
30use common_recordbatch::error::ExternalSnafu;
31use common_recordbatch::{RecordBatchStreamWrapper, SendableRecordBatchStream};
32use common_telemetry::error;
33use common_telemetry::tracing_context::TracingContext;
34use prost::Message;
35use query::query_engine::DefaultSerializer;
36use snafu::{location, OptionExt, ResultExt};
37use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
38use tokio_stream::StreamExt;
39
40use crate::error::{
41    self, ConvertFlightDataSnafu, FlightGetSnafu, IllegalDatabaseResponseSnafu,
42    IllegalFlightMessagesSnafu, MissingFieldSnafu, Result, ServerSnafu,
43};
44use crate::{metrics, Client, Error};
45
46#[derive(Debug)]
47pub struct RegionRequester {
48    client: Client,
49}
50
51#[async_trait]
52impl Datanode for RegionRequester {
53    async fn handle(&self, request: RegionRequest) -> MetaResult<RegionResponse> {
54        self.handle_inner(request).await.map_err(|err| {
55            if err.should_retry() {
56                meta_error::Error::RetryLater {
57                    source: BoxedError::new(err),
58                }
59            } else {
60                meta_error::Error::External {
61                    source: BoxedError::new(err),
62                    location: location!(),
63                }
64            }
65        })
66    }
67
68    async fn handle_query(&self, request: QueryRequest) -> MetaResult<SendableRecordBatchStream> {
69        let plan = DFLogicalSubstraitConvertor
70            .encode(&request.plan, DefaultSerializer)
71            .map_err(BoxedError::new)
72            .context(meta_error::ExternalSnafu)?
73            .to_vec();
74        let request = api::v1::region::QueryRequest {
75            header: request.header,
76            region_id: request.region_id.as_u64(),
77            plan,
78        };
79
80        let ticket = Ticket {
81            ticket: request.encode_to_vec().into(),
82        };
83        self.do_get_inner(ticket)
84            .await
85            .map_err(BoxedError::new)
86            .context(meta_error::ExternalSnafu)
87    }
88}
89
90impl RegionRequester {
91    pub fn new(client: Client) -> Self {
92        Self { client }
93    }
94
95    pub async fn do_get_inner(&self, ticket: Ticket) -> Result<SendableRecordBatchStream> {
96        let mut flight_client = self.client.make_flight_client()?;
97        let response = flight_client
98            .mut_inner()
99            .do_get(ticket)
100            .await
101            .map_err(|e| {
102                let tonic_code = e.code();
103                let e: error::Error = e.into();
104                let code = e.status_code();
105                let msg = e.to_string();
106                let error = ServerSnafu { code, msg }
107                    .fail::<()>()
108                    .map_err(BoxedError::new)
109                    .with_context(|_| FlightGetSnafu {
110                        tonic_code,
111                        addr: flight_client.addr().to_string(),
112                    })
113                    .unwrap_err();
114                error!(
115                    e; "Failed to do Flight get, addr: {}, code: {}",
116                    flight_client.addr(),
117                    tonic_code
118                );
119                error
120            })?;
121
122        let flight_data_stream = response.into_inner();
123        let mut decoder = FlightDecoder::default();
124
125        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
126            flight_data
127                .map_err(Error::from)
128                .and_then(|data| decoder.try_decode(data).context(ConvertFlightDataSnafu))
129        });
130
131        let Some(first_flight_message) = flight_message_stream.next().await else {
132            return IllegalFlightMessagesSnafu {
133                reason: "Expect the response not to be empty",
134            }
135            .fail();
136        };
137        let FlightMessage::Schema(schema) = first_flight_message? else {
138            return IllegalFlightMessagesSnafu {
139                reason: "Expect schema to be the first flight message",
140            }
141            .fail();
142        };
143
144        let metrics = Arc::new(ArcSwapOption::from(None));
145        let metrics_ref = metrics.clone();
146
147        let tracing_context = TracingContext::from_current_span();
148
149        let stream = Box::pin(stream!({
150            let _span = tracing_context.attach(common_telemetry::tracing::info_span!(
151                "poll_flight_data_stream"
152            ));
153            while let Some(flight_message) = flight_message_stream.next().await {
154                let flight_message = flight_message
155                    .map_err(BoxedError::new)
156                    .context(ExternalSnafu)?;
157
158                match flight_message {
159                    FlightMessage::Recordbatch(record_batch) => yield Ok(record_batch),
160                    FlightMessage::Metrics(s) => {
161                        let m = serde_json::from_str(&s).ok().map(Arc::new);
162                        metrics_ref.swap(m);
163                        break;
164                    }
165                    _ => {
166                        yield IllegalFlightMessagesSnafu {
167                            reason: "A Schema message must be succeeded exclusively by a set of RecordBatch messages"
168                        }
169                        .fail()
170                        .map_err(BoxedError::new)
171                        .context(ExternalSnafu);
172                        break;
173                    }
174                }
175            }
176        }));
177        let record_batch_stream = RecordBatchStreamWrapper {
178            schema,
179            stream,
180            output_ordering: None,
181            metrics,
182        };
183        Ok(Box::pin(record_batch_stream))
184    }
185
186    async fn handle_inner(&self, request: RegionRequest) -> Result<RegionResponse> {
187        let request_type = request
188            .body
189            .as_ref()
190            .with_context(|| MissingFieldSnafu { field: "body" })?
191            .as_ref()
192            .to_string();
193        let _timer = metrics::METRIC_REGION_REQUEST_GRPC
194            .with_label_values(&[request_type.as_str()])
195            .start_timer();
196
197        let (addr, mut client) = self.client.raw_region_client()?;
198
199        let response = client
200            .handle(request)
201            .await
202            .map_err(|e| {
203                let code = e.code();
204                // Uses `Error::RegionServer` instead of `Error::Server`
205                error::Error::RegionServer {
206                    addr,
207                    code,
208                    source: BoxedError::new(error::Error::from(e)),
209                    location: location!(),
210                }
211            })?
212            .into_inner();
213
214        check_response_header(&response.header)?;
215
216        Ok(RegionResponse::from_region_response(response))
217    }
218
219    pub async fn handle(&self, request: RegionRequest) -> Result<RegionResponse> {
220        self.handle_inner(request).await
221    }
222}
223
224pub fn check_response_header(header: &Option<ResponseHeader>) -> Result<()> {
225    let status = header
226        .as_ref()
227        .and_then(|header| header.status.as_ref())
228        .context(IllegalDatabaseResponseSnafu {
229            err_msg: "either response header or status is missing",
230        })?;
231
232    if StatusCode::is_success(status.status_code) {
233        Ok(())
234    } else {
235        let code =
236            StatusCode::from_u32(status.status_code).context(IllegalDatabaseResponseSnafu {
237                err_msg: format!("unknown server status: {:?}", status),
238            })?;
239        ServerSnafu {
240            code,
241            msg: status.err_msg.clone(),
242        }
243        .fail()
244    }
245}
246
#[cfg(test)]
mod test {
    use api::v1::Status as PbStatus;

    use super::*;
    use crate::Error::{IllegalDatabaseResponse, Server};

    /// Builds a present response header whose status has the given code and message.
    fn header_with_status(status_code: u32, err_msg: &str) -> Option<ResponseHeader> {
        Some(ResponseHeader {
            status: Some(PbStatus {
                status_code,
                err_msg: err_msg.to_string(),
            }),
        })
    }

    #[test]
    fn test_check_response_header() {
        // No header at all is rejected as an illegal response.
        let missing_header = check_response_header(&None).unwrap_err();
        assert!(matches!(missing_header, IllegalDatabaseResponse { .. }));

        // A header without a status is rejected as well.
        let missing_status =
            check_response_header(&Some(ResponseHeader { status: None })).unwrap_err();
        assert!(matches!(missing_status, IllegalDatabaseResponse { .. }));

        // A success status passes.
        let success = header_with_status(StatusCode::Success as u32, "");
        assert!(check_response_header(&success).is_ok());

        // A status code outside the known set is an illegal response.
        let unknown = header_with_status(u32::MAX, "");
        let unknown_err = check_response_header(&unknown).unwrap_err();
        assert!(matches!(unknown_err, IllegalDatabaseResponse { .. }));

        // A known failure code surfaces as a server error with the original code and message.
        let internal = header_with_status(StatusCode::Internal as u32, "blabla");
        let Server { code, msg, .. } = check_response_header(&internal).unwrap_err() else {
            unreachable!()
        };
        assert_eq!(code, StatusCode::Internal);
        assert_eq!(msg, "blabla");
    }
}