client/
region.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use api::region::RegionResponse;
18use api::v1::ResponseHeader;
19use api::v1::region::RegionRequest;
20use arc_swap::ArcSwapOption;
21use arrow_flight::Ticket;
22use async_stream::stream;
23use async_trait::async_trait;
24use common_error::ext::BoxedError;
25use common_error::status_code::StatusCode;
26use common_grpc::flight::{FlightDecoder, FlightMessage};
27use common_meta::error::{self as meta_error, Result as MetaResult};
28use common_meta::node_manager::Datanode;
29use common_query::request::QueryRequest;
30use common_recordbatch::error::ExternalSnafu;
31use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper, SendableRecordBatchStream};
32use common_telemetry::error;
33use common_telemetry::tracing::Span;
34use common_telemetry::tracing_context::TracingContext;
35use prost::Message;
36use query::query_engine::DefaultSerializer;
37use snafu::{OptionExt, ResultExt, location};
38use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
39use tokio_stream::StreamExt;
40
41use crate::error::{
42    self, ConvertFlightDataSnafu, FlightGetSnafu, IllegalDatabaseResponseSnafu,
43    IllegalFlightMessagesSnafu, MissingFieldSnafu, Result, ServerSnafu,
44};
45use crate::{Client, Error, metrics};
46
47#[derive(Debug)]
48pub struct RegionRequester {
49    client: Client,
50    send_compression: bool,
51    accept_compression: bool,
52}
53
54#[async_trait]
55impl Datanode for RegionRequester {
56    async fn handle(&self, request: RegionRequest) -> MetaResult<RegionResponse> {
57        self.handle_inner(request).await.map_err(|err| {
58            if err.should_retry() {
59                meta_error::Error::RetryLater {
60                    source: BoxedError::new(err),
61                    clean_poisons: false,
62                }
63            } else {
64                meta_error::Error::External {
65                    source: BoxedError::new(err),
66                    location: location!(),
67                }
68            }
69        })
70    }
71
72    async fn handle_query(&self, request: QueryRequest) -> MetaResult<SendableRecordBatchStream> {
73        let plan = DFLogicalSubstraitConvertor
74            .encode(&request.plan, DefaultSerializer)
75            .map_err(BoxedError::new)
76            .context(meta_error::ExternalSnafu)?
77            .to_vec();
78        let request = api::v1::region::QueryRequest {
79            header: request.header,
80            region_id: request.region_id.as_u64(),
81            plan,
82        };
83
84        let ticket = Ticket {
85            ticket: request.encode_to_vec().into(),
86        };
87        self.do_get_inner(ticket)
88            .await
89            .map_err(BoxedError::new)
90            .context(meta_error::ExternalSnafu)
91    }
92}
93
94impl RegionRequester {
95    pub fn new(client: Client, send_compression: bool, accept_compression: bool) -> Self {
96        Self {
97            client,
98            send_compression,
99            accept_compression,
100        }
101    }
102
103    pub async fn do_get_inner(&self, ticket: Ticket) -> Result<SendableRecordBatchStream> {
104        let mut flight_client = self
105            .client
106            .make_flight_client(self.send_compression, self.accept_compression)?;
107        let response = flight_client
108            .mut_inner()
109            .do_get(ticket)
110            .await
111            .or_else(|e| {
112                let tonic_code = e.code();
113                let e: error::Error = e.into();
114                error!(
115                    e; "Failed to do Flight get, addr: {}, code: {}",
116                    flight_client.addr(),
117                    tonic_code
118                );
119                Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
120                    addr: flight_client.addr().to_string(),
121                    tonic_code,
122                })
123            })?;
124
125        let flight_data_stream = response.into_inner();
126        let mut decoder = FlightDecoder::default();
127
128        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
129            flight_data
130                .map_err(Error::from)
131                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
132                .context(IllegalFlightMessagesSnafu {
133                    reason: "none message",
134                })
135        });
136
137        let Some(first_flight_message) = flight_message_stream.next().await else {
138            return IllegalFlightMessagesSnafu {
139                reason: "Expect the response not to be empty",
140            }
141            .fail();
142        };
143        let FlightMessage::Schema(schema) = first_flight_message? else {
144            return IllegalFlightMessagesSnafu {
145                reason: "Expect schema to be the first flight message",
146            }
147            .fail();
148        };
149
150        let metrics = Arc::new(ArcSwapOption::from(None));
151        let metrics_ref = metrics.clone();
152
153        let tracing_context = TracingContext::from_current_span();
154
155        let schema = Arc::new(
156            datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
157        );
158        let schema_cloned = schema.clone();
159        let stream = Box::pin(stream!({
160            let _span = tracing_context.attach(common_telemetry::tracing::info_span!(
161                "poll_flight_data_stream"
162            ));
163
164            let mut buffered_message: Option<FlightMessage> = None;
165            let mut stream_ended = false;
166
167            while !stream_ended {
168                // get the next message from the buffered message or read from the flight message stream
169                let flight_message_item = if let Some(msg) = buffered_message.take() {
170                    Some(Ok(msg))
171                } else {
172                    flight_message_stream.next().await
173                };
174
175                let flight_message = match flight_message_item {
176                    Some(Ok(message)) => message,
177                    Some(Err(e)) => {
178                        yield Err(BoxedError::new(e)).context(ExternalSnafu);
179                        break;
180                    }
181                    None => break,
182                };
183
184                match flight_message {
185                    FlightMessage::RecordBatch(record_batch) => {
186                        let result_to_yield =
187                            RecordBatch::from_df_record_batch(schema_cloned.clone(), record_batch);
188
189                        // get the next message from the stream. normally it should be a metrics message.
190                        if let Some(next_flight_message_result) = flight_message_stream.next().await
191                        {
192                            match next_flight_message_result {
193                                Ok(FlightMessage::Metrics(s)) => {
194                                    let m = serde_json::from_str(&s).ok().map(Arc::new);
195                                    metrics_ref.swap(m);
196                                }
197                                Ok(FlightMessage::RecordBatch(rb)) => {
198                                    // for some reason it's not a metrics message, so we need to buffer this record batch
199                                    // and yield it in the next iteration.
200                                    buffered_message = Some(FlightMessage::RecordBatch(rb));
201                                }
202                                Ok(_) => {
203                                    yield IllegalFlightMessagesSnafu {
204                                        reason: "A RecordBatch message can only be succeeded by a Metrics message or another RecordBatch message"
205                                    }
206                                    .fail()
207                                    .map_err(BoxedError::new)
208                                    .context(ExternalSnafu);
209                                    break;
210                                }
211                                Err(e) => {
212                                    yield Err(BoxedError::new(e)).context(ExternalSnafu);
213                                    break;
214                                }
215                            }
216                        } else {
217                            // the stream has ended
218                            stream_ended = true;
219                        }
220
221                        yield Ok(result_to_yield);
222                    }
223                    FlightMessage::Metrics(s) => {
224                        // just a branch in case of some metrics message comes after other things.
225                        let m = serde_json::from_str(&s).ok().map(Arc::new);
226                        metrics_ref.swap(m);
227                        break;
228                    }
229                    _ => {
230                        yield IllegalFlightMessagesSnafu {
231                            reason: "A Schema message must be succeeded exclusively by a set of RecordBatch messages"
232                        }
233                        .fail()
234                        .map_err(BoxedError::new)
235                        .context(ExternalSnafu);
236                        break;
237                    }
238                }
239            }
240        }));
241        let record_batch_stream = RecordBatchStreamWrapper {
242            schema,
243            stream,
244            output_ordering: None,
245            metrics,
246            span: Span::current(),
247        };
248        Ok(Box::pin(record_batch_stream))
249    }
250
251    async fn handle_inner(&self, request: RegionRequest) -> Result<RegionResponse> {
252        let request_type = request
253            .body
254            .as_ref()
255            .with_context(|| MissingFieldSnafu { field: "body" })?
256            .as_ref()
257            .to_string();
258        let _timer = metrics::METRIC_REGION_REQUEST_GRPC
259            .with_label_values(&[request_type.as_str()])
260            .start_timer();
261
262        let (addr, mut client) = self.client.raw_region_client()?;
263
264        let response = client
265            .handle(request)
266            .await
267            .map_err(|e| {
268                let code = e.code();
269                // Uses `Error::RegionServer` instead of `Error::Server`
270                error::Error::RegionServer {
271                    addr,
272                    code,
273                    source: BoxedError::new(error::Error::from(e)),
274                    location: location!(),
275                }
276            })?
277            .into_inner();
278
279        check_response_header(&response.header)?;
280
281        Ok(RegionResponse::from_region_response(response))
282    }
283
284    pub async fn handle(&self, request: RegionRequest) -> Result<RegionResponse> {
285        self.handle_inner(request).await
286    }
287}
288
289pub fn check_response_header(header: &Option<ResponseHeader>) -> Result<()> {
290    let status = header
291        .as_ref()
292        .and_then(|header| header.status.as_ref())
293        .context(IllegalDatabaseResponseSnafu {
294            err_msg: "either response header or status is missing",
295        })?;
296
297    if StatusCode::is_success(status.status_code) {
298        Ok(())
299    } else {
300        let code =
301            StatusCode::from_u32(status.status_code).context(IllegalDatabaseResponseSnafu {
302                err_msg: format!("unknown server status: {:?}", status),
303            })?;
304        ServerSnafu {
305            code,
306            msg: status.err_msg.clone(),
307        }
308        .fail()
309    }
310}
311
#[cfg(test)]
mod test {
    use api::v1::Status as PbStatus;

    use super::*;
    use crate::Error::{IllegalDatabaseResponse, Server};

    /// Builds a response header whose status carries the given raw code and message.
    fn header_with(status_code: u32, err_msg: &str) -> Option<ResponseHeader> {
        Some(ResponseHeader {
            status: Some(PbStatus {
                status_code,
                err_msg: err_msg.to_string(),
            }),
        })
    }

    #[test]
    fn test_check_response_header() {
        // No header at all.
        let err = check_response_header(&None).unwrap_err();
        assert!(matches!(err, IllegalDatabaseResponse { .. }));

        // Header present but without a status.
        let err = check_response_header(&Some(ResponseHeader { status: None })).unwrap_err();
        assert!(matches!(err, IllegalDatabaseResponse { .. }));

        // Successful status.
        let ok = check_response_header(&header_with(StatusCode::Success as u32, ""));
        assert!(ok.is_ok());

        // Status code that does not map to any known `StatusCode`.
        let err = check_response_header(&header_with(u32::MAX, "")).unwrap_err();
        assert!(matches!(err, IllegalDatabaseResponse { .. }));

        // Known failure code: the code and the message are propagated.
        let err =
            check_response_header(&header_with(StatusCode::Internal as u32, "blabla")).unwrap_err();
        match err {
            Server { code, msg, .. } => {
                assert_eq!(code, StatusCode::Internal);
                assert_eq!(msg, "blabla");
            }
            _ => unreachable!(),
        }
    }
}
364}