client/
region.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use api::region::RegionResponse;
18use api::v1::ResponseHeader;
19use api::v1::region::RegionRequest;
20use arc_swap::ArcSwapOption;
21use arrow_flight::Ticket;
22use async_stream::stream;
23use async_trait::async_trait;
24use common_error::ext::BoxedError;
25use common_error::status_code::StatusCode;
26use common_grpc::flight::{FlightDecoder, FlightMessage};
27use common_meta::error::{self as meta_error, Result as MetaResult};
28use common_meta::node_manager::Datanode;
29use common_query::request::QueryRequest;
30use common_recordbatch::error::ExternalSnafu;
31use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper, SendableRecordBatchStream};
32use common_telemetry::error;
33use common_telemetry::tracing_context::TracingContext;
34use prost::Message;
35use query::query_engine::DefaultSerializer;
36use snafu::{OptionExt, ResultExt, location};
37use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
38use tokio_stream::StreamExt;
39
40use crate::error::{
41    self, ConvertFlightDataSnafu, FlightGetSnafu, IllegalDatabaseResponseSnafu,
42    IllegalFlightMessagesSnafu, MissingFieldSnafu, Result, ServerSnafu,
43};
44use crate::{Client, Error, metrics};
45
46#[derive(Debug)]
47pub struct RegionRequester {
48    client: Client,
49    send_compression: bool,
50    accept_compression: bool,
51}
52
53#[async_trait]
54impl Datanode for RegionRequester {
55    async fn handle(&self, request: RegionRequest) -> MetaResult<RegionResponse> {
56        self.handle_inner(request).await.map_err(|err| {
57            if err.should_retry() {
58                meta_error::Error::RetryLater {
59                    source: BoxedError::new(err),
60                    clean_poisons: false,
61                }
62            } else {
63                meta_error::Error::External {
64                    source: BoxedError::new(err),
65                    location: location!(),
66                }
67            }
68        })
69    }
70
71    async fn handle_query(&self, request: QueryRequest) -> MetaResult<SendableRecordBatchStream> {
72        let plan = DFLogicalSubstraitConvertor
73            .encode(&request.plan, DefaultSerializer)
74            .map_err(BoxedError::new)
75            .context(meta_error::ExternalSnafu)?
76            .to_vec();
77        let request = api::v1::region::QueryRequest {
78            header: request.header,
79            region_id: request.region_id.as_u64(),
80            plan,
81        };
82
83        let ticket = Ticket {
84            ticket: request.encode_to_vec().into(),
85        };
86        self.do_get_inner(ticket)
87            .await
88            .map_err(BoxedError::new)
89            .context(meta_error::ExternalSnafu)
90    }
91}
92
93impl RegionRequester {
94    pub fn new(client: Client, send_compression: bool, accept_compression: bool) -> Self {
95        Self {
96            client,
97            send_compression,
98            accept_compression,
99        }
100    }
101
102    pub async fn do_get_inner(&self, ticket: Ticket) -> Result<SendableRecordBatchStream> {
103        let mut flight_client = self
104            .client
105            .make_flight_client(self.send_compression, self.accept_compression)?;
106        let response = flight_client
107            .mut_inner()
108            .do_get(ticket)
109            .await
110            .or_else(|e| {
111                let tonic_code = e.code();
112                let e: error::Error = e.into();
113                error!(
114                    e; "Failed to do Flight get, addr: {}, code: {}",
115                    flight_client.addr(),
116                    tonic_code
117                );
118                Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
119                    addr: flight_client.addr().to_string(),
120                    tonic_code,
121                })
122            })?;
123
124        let flight_data_stream = response.into_inner();
125        let mut decoder = FlightDecoder::default();
126
127        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
128            flight_data
129                .map_err(Error::from)
130                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
131                .context(IllegalFlightMessagesSnafu {
132                    reason: "none message",
133                })
134        });
135
136        let Some(first_flight_message) = flight_message_stream.next().await else {
137            return IllegalFlightMessagesSnafu {
138                reason: "Expect the response not to be empty",
139            }
140            .fail();
141        };
142        let FlightMessage::Schema(schema) = first_flight_message? else {
143            return IllegalFlightMessagesSnafu {
144                reason: "Expect schema to be the first flight message",
145            }
146            .fail();
147        };
148
149        let metrics = Arc::new(ArcSwapOption::from(None));
150        let metrics_ref = metrics.clone();
151
152        let tracing_context = TracingContext::from_current_span();
153
154        let schema = Arc::new(
155            datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
156        );
157        let schema_cloned = schema.clone();
158        let stream = Box::pin(stream!({
159            let _span = tracing_context.attach(common_telemetry::tracing::info_span!(
160                "poll_flight_data_stream"
161            ));
162
163            let mut buffered_message: Option<FlightMessage> = None;
164            let mut stream_ended = false;
165
166            while !stream_ended {
167                // get the next message from the buffered message or read from the flight message stream
168                let flight_message_item = if let Some(msg) = buffered_message.take() {
169                    Some(Ok(msg))
170                } else {
171                    flight_message_stream.next().await
172                };
173
174                let flight_message = match flight_message_item {
175                    Some(Ok(message)) => message,
176                    Some(Err(e)) => {
177                        yield Err(BoxedError::new(e)).context(ExternalSnafu);
178                        break;
179                    }
180                    None => break,
181                };
182
183                match flight_message {
184                    FlightMessage::RecordBatch(record_batch) => {
185                        let result_to_yield =
186                            RecordBatch::from_df_record_batch(schema_cloned.clone(), record_batch);
187
188                        // get the next message from the stream. normally it should be a metrics message.
189                        if let Some(next_flight_message_result) = flight_message_stream.next().await
190                        {
191                            match next_flight_message_result {
192                                Ok(FlightMessage::Metrics(s)) => {
193                                    let m = serde_json::from_str(&s).ok().map(Arc::new);
194                                    metrics_ref.swap(m);
195                                }
196                                Ok(FlightMessage::RecordBatch(rb)) => {
197                                    // for some reason it's not a metrics message, so we need to buffer this record batch
198                                    // and yield it in the next iteration.
199                                    buffered_message = Some(FlightMessage::RecordBatch(rb));
200                                }
201                                Ok(_) => {
202                                    yield IllegalFlightMessagesSnafu {
203                                        reason: "A RecordBatch message can only be succeeded by a Metrics message or another RecordBatch message"
204                                    }
205                                    .fail()
206                                    .map_err(BoxedError::new)
207                                    .context(ExternalSnafu);
208                                    break;
209                                }
210                                Err(e) => {
211                                    yield Err(BoxedError::new(e)).context(ExternalSnafu);
212                                    break;
213                                }
214                            }
215                        } else {
216                            // the stream has ended
217                            stream_ended = true;
218                        }
219
220                        yield Ok(result_to_yield);
221                    }
222                    FlightMessage::Metrics(s) => {
223                        // just a branch in case of some metrics message comes after other things.
224                        let m = serde_json::from_str(&s).ok().map(Arc::new);
225                        metrics_ref.swap(m);
226                        break;
227                    }
228                    _ => {
229                        yield IllegalFlightMessagesSnafu {
230                            reason: "A Schema message must be succeeded exclusively by a set of RecordBatch messages"
231                        }
232                        .fail()
233                        .map_err(BoxedError::new)
234                        .context(ExternalSnafu);
235                        break;
236                    }
237                }
238            }
239        }));
240        let record_batch_stream = RecordBatchStreamWrapper {
241            schema,
242            stream,
243            output_ordering: None,
244            metrics,
245        };
246        Ok(Box::pin(record_batch_stream))
247    }
248
249    async fn handle_inner(&self, request: RegionRequest) -> Result<RegionResponse> {
250        let request_type = request
251            .body
252            .as_ref()
253            .with_context(|| MissingFieldSnafu { field: "body" })?
254            .as_ref()
255            .to_string();
256        let _timer = metrics::METRIC_REGION_REQUEST_GRPC
257            .with_label_values(&[request_type.as_str()])
258            .start_timer();
259
260        let (addr, mut client) = self.client.raw_region_client()?;
261
262        let response = client
263            .handle(request)
264            .await
265            .map_err(|e| {
266                let code = e.code();
267                // Uses `Error::RegionServer` instead of `Error::Server`
268                error::Error::RegionServer {
269                    addr,
270                    code,
271                    source: BoxedError::new(error::Error::from(e)),
272                    location: location!(),
273                }
274            })?
275            .into_inner();
276
277        check_response_header(&response.header)?;
278
279        Ok(RegionResponse::from_region_response(response))
280    }
281
282    pub async fn handle(&self, request: RegionRequest) -> Result<RegionResponse> {
283        self.handle_inner(request).await
284    }
285}
286
287pub fn check_response_header(header: &Option<ResponseHeader>) -> Result<()> {
288    let status = header
289        .as_ref()
290        .and_then(|header| header.status.as_ref())
291        .context(IllegalDatabaseResponseSnafu {
292            err_msg: "either response header or status is missing",
293        })?;
294
295    if StatusCode::is_success(status.status_code) {
296        Ok(())
297    } else {
298        let code =
299            StatusCode::from_u32(status.status_code).context(IllegalDatabaseResponseSnafu {
300                err_msg: format!("unknown server status: {:?}", status),
301            })?;
302        ServerSnafu {
303            code,
304            msg: status.err_msg.clone(),
305        }
306        .fail()
307    }
308}
309
#[cfg(test)]
mod test {
    use api::v1::Status as PbStatus;

    use super::*;
    use crate::Error::{IllegalDatabaseResponse, Server};

    /// Builds a response header whose status carries the given code and message.
    fn header_with_status(status_code: u32, err_msg: &str) -> Option<ResponseHeader> {
        Some(ResponseHeader {
            status: Some(PbStatus {
                status_code,
                err_msg: err_msg.to_string(),
            }),
        })
    }

    #[test]
    fn test_check_response_header() {
        // A missing header and a header without status are both illegal responses.
        for header in [None, Some(ResponseHeader { status: None })] {
            let err = check_response_header(&header).unwrap_err();
            assert!(matches!(err, IllegalDatabaseResponse { .. }));
        }

        // A success status passes.
        let success = header_with_status(StatusCode::Success as u32, "");
        assert!(check_response_header(&success).is_ok());

        // A status code that cannot be decoded is an illegal response.
        let unknown = header_with_status(u32::MAX, "");
        let err = check_response_header(&unknown).unwrap_err();
        assert!(matches!(err, IllegalDatabaseResponse { .. }));

        // A known failure code surfaces as a server error with code and message.
        let internal = header_with_status(StatusCode::Internal as u32, "blabla");
        match check_response_header(&internal).unwrap_err() {
            Server { code, msg, .. } => {
                assert_eq!(code, StatusCode::Internal);
                assert_eq!(msg, "blabla");
            }
            _ => unreachable!(),
        }
    }
}