common_grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::{RecordBatch, RecordBatches};
25use datatypes::arrow;
26use datatypes::arrow::datatypes::Schema as ArrowSchema;
27use datatypes::arrow::ipc::{root_as_message, writer, MessageHeader};
28use datatypes::schema::{Schema, SchemaRef};
29use flatbuffers::FlatBufferBuilder;
30use prost::bytes::Bytes as ProstBytes;
31use prost::Message;
32use snafu::{OptionExt, ResultExt};
33
34use crate::error::{
35    ConvertArrowSchemaSnafu, CreateRecordBatchSnafu, DecodeFlightDataSnafu, InvalidFlightDataSnafu,
36    Result,
37};
38
39#[derive(Debug, Clone)]
40pub enum FlightMessage {
41    Schema(SchemaRef),
42    Recordbatch(RecordBatch),
43    AffectedRows(usize),
44    Metrics(String),
45}
46
47pub struct FlightEncoder {
48    write_options: writer::IpcWriteOptions,
49    data_gen: writer::IpcDataGenerator,
50    dictionary_tracker: writer::DictionaryTracker,
51}
52
53impl Default for FlightEncoder {
54    fn default() -> Self {
55        let write_options = writer::IpcWriteOptions::default()
56            .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
57            .unwrap();
58
59        Self {
60            write_options,
61            data_gen: writer::IpcDataGenerator::default(),
62            dictionary_tracker: writer::DictionaryTracker::new(false),
63        }
64    }
65}
66
67impl FlightEncoder {
68    pub fn encode(&mut self, flight_message: FlightMessage) -> FlightData {
69        match flight_message {
70            FlightMessage::Schema(schema) => {
71                SchemaAsIpc::new(schema.arrow_schema(), &self.write_options).into()
72            }
73            FlightMessage::Recordbatch(recordbatch) => {
74                let (encoded_dictionaries, encoded_batch) = self
75                    .data_gen
76                    .encoded_batch(
77                        recordbatch.df_record_batch(),
78                        &mut self.dictionary_tracker,
79                        &self.write_options,
80                    )
81                    .expect("DictionaryTracker configured above to not fail on replacement");
82
83                // TODO(LFC): Handle dictionary as FlightData here, when we supported Arrow's Dictionary DataType.
84                // Currently we don't have a datatype corresponding to Arrow's Dictionary DataType,
85                // so there won't be any "dictionaries" here. Assert to be sure about it, and
86                // perform a "testing guard" in case we forgot to handle the possible "dictionaries"
87                // here in the future.
88                debug_assert_eq!(encoded_dictionaries.len(), 0);
89
90                encoded_batch.into()
91            }
92            FlightMessage::AffectedRows(rows) => {
93                let metadata = FlightMetadata {
94                    affected_rows: Some(AffectedRows { value: rows as _ }),
95                    metrics: None,
96                }
97                .encode_to_vec();
98                FlightData {
99                    flight_descriptor: None,
100                    data_header: build_none_flight_msg().into(),
101                    app_metadata: metadata.into(),
102                    data_body: ProstBytes::default(),
103                }
104            }
105            FlightMessage::Metrics(s) => {
106                let metadata = FlightMetadata {
107                    affected_rows: None,
108                    metrics: Some(Metrics {
109                        metrics: s.as_bytes().to_vec(),
110                    }),
111                }
112                .encode_to_vec();
113                FlightData {
114                    flight_descriptor: None,
115                    data_header: build_none_flight_msg().into(),
116                    app_metadata: metadata.into(),
117                    data_body: ProstBytes::default(),
118                }
119            }
120        }
121    }
122}
123
124#[derive(Default)]
125pub struct FlightDecoder {
126    schema: Option<SchemaRef>,
127}
128
129impl FlightDecoder {
130    pub fn try_decode(&mut self, flight_data: FlightData) -> Result<FlightMessage> {
131        let message = root_as_message(&flight_data.data_header).map_err(|e| {
132            InvalidFlightDataSnafu {
133                reason: e.to_string(),
134            }
135            .build()
136        })?;
137        match message.header_type() {
138            MessageHeader::NONE => {
139                let metadata = FlightMetadata::decode(flight_data.app_metadata)
140                    .context(DecodeFlightDataSnafu)?;
141                if let Some(AffectedRows { value }) = metadata.affected_rows {
142                    return Ok(FlightMessage::AffectedRows(value as _));
143                }
144                if let Some(Metrics { metrics }) = metadata.metrics {
145                    return Ok(FlightMessage::Metrics(
146                        String::from_utf8_lossy(&metrics).to_string(),
147                    ));
148                }
149                InvalidFlightDataSnafu {
150                    reason: "Expecting FlightMetadata have some meaningful content.",
151                }
152                .fail()
153            }
154            MessageHeader::Schema => {
155                let arrow_schema = ArrowSchema::try_from(&flight_data).map_err(|e| {
156                    InvalidFlightDataSnafu {
157                        reason: e.to_string(),
158                    }
159                    .build()
160                })?;
161                let schema =
162                    Arc::new(Schema::try_from(arrow_schema).context(ConvertArrowSchemaSnafu)?);
163
164                self.schema = Some(schema.clone());
165
166                Ok(FlightMessage::Schema(schema))
167            }
168            MessageHeader::RecordBatch => {
169                let schema = self.schema.clone().context(InvalidFlightDataSnafu {
170                    reason: "Should have decoded schema first!",
171                })?;
172                let arrow_schema = schema.arrow_schema().clone();
173
174                let arrow_batch =
175                    flight_data_to_arrow_batch(&flight_data, arrow_schema, &HashMap::new())
176                        .map_err(|e| {
177                            InvalidFlightDataSnafu {
178                                reason: e.to_string(),
179                            }
180                            .build()
181                        })?;
182                let recordbatch = RecordBatch::try_from_df_record_batch(schema, arrow_batch)
183                    .context(CreateRecordBatchSnafu)?;
184                Ok(FlightMessage::Recordbatch(recordbatch))
185            }
186            other => {
187                let name = other.variant_name().unwrap_or("UNKNOWN");
188                InvalidFlightDataSnafu {
189                    reason: format!("Unsupported FlightData type: {name}"),
190                }
191                .fail()
192            }
193        }
194    }
195}
196
197pub fn flight_messages_to_recordbatches(messages: Vec<FlightMessage>) -> Result<RecordBatches> {
198    if messages.is_empty() {
199        Ok(RecordBatches::empty())
200    } else {
201        let mut recordbatches = Vec::with_capacity(messages.len() - 1);
202
203        let schema = match &messages[0] {
204            FlightMessage::Schema(schema) => schema.clone(),
205            _ => {
206                return InvalidFlightDataSnafu {
207                    reason: "First Flight Message must be schema!",
208                }
209                .fail()
210            }
211        };
212
213        for message in messages.into_iter().skip(1) {
214            match message {
215                FlightMessage::Recordbatch(recordbatch) => recordbatches.push(recordbatch),
216                _ => {
217                    return InvalidFlightDataSnafu {
218                        reason: "Expect the following Flight Messages are all Recordbatches!",
219                    }
220                    .fail()
221                }
222            }
223        }
224
225        RecordBatches::try_new(schema, recordbatches).context(CreateRecordBatchSnafu)
226    }
227}
228
229fn build_none_flight_msg() -> Bytes {
230    let mut builder = FlatBufferBuilder::new();
231
232    let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
233    message.add_version(arrow::ipc::MetadataVersion::V5);
234    message.add_header_type(MessageHeader::NONE);
235    message.add_bodyLength(0);
236
237    let data = message.finish();
238    builder.finish(data, None);
239
240    builder.finished_data().into()
241}
242
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::datatypes::{DataType, Field};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::Int32Vector;

    use super::*;
    use crate::Error;

    /// Decodes a schema frame followed by two record-batch frames.
    /// Also checks that a record batch arriving before any schema is rejected.
    #[test]
    fn test_try_decode() {
        let arrow_schema = ArrowSchema::new(vec![Field::new("n", DataType::Int32, true)]);
        let schema = Arc::new(Schema::try_from(arrow_schema.clone()).unwrap());

        let batch1 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from(vec![Some(1), None, Some(3)])) as _],
        )
        .unwrap();
        let batch2 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from(vec![None, Some(5)])) as _],
        )
        .unwrap();

        // Expected layout: the schema frame first, then one frame per batch.
        let flight_data = batches_to_flight_data(
            &arrow_schema,
            vec![
                batch1.clone().into_df_record_batch(),
                batch2.clone().into_df_record_batch(),
            ],
        )
        .unwrap();
        let [d1, d2, d3] = flight_data.as_slice() else {
            panic!("expected 3 FlightData frames, got {}", flight_data.len())
        };

        let decoder = &mut FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // A record batch cannot be decoded before its schema is known.
        let result = decoder.try_decode(d2.clone());
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Should have decoded schema first!"));

        let FlightMessage::Schema(decoded_schema) = decoder.try_decode(d1.clone()).unwrap() else {
            panic!("expected a Schema message from the first frame")
        };
        assert_eq!(decoded_schema, schema);
        assert!(decoder.schema.is_some());

        let FlightMessage::Recordbatch(actual_batch) = decoder.try_decode(d2.clone()).unwrap()
        else {
            panic!("expected a Recordbatch message from the second frame")
        };
        assert_eq!(actual_batch, batch1);

        let FlightMessage::Recordbatch(actual_batch) = decoder.try_decode(d3.clone()).unwrap()
        else {
            panic!("expected a Recordbatch message from the third frame")
        };
        assert_eq!(actual_batch, batch2);
    }

    /// A message sequence must be exactly one schema followed only by record batches.
    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "m",
            ConcreteDataType::int32_datatype(),
            true,
        )]));
        let batch1 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from(vec![Some(2), None, Some(4)])) as _],
        )
        .unwrap();
        let batch2 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from(vec![None, Some(6)])) as _],
        )
        .unwrap();
        let recordbatches =
            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::Recordbatch(batch1);
        let m3 = FlightMessage::Recordbatch(batch2);

        // The schema is not first.
        let result = flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("First Flight Message must be schema!"));

        // A second schema appears after the first one.
        let result = flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Expect the following Flight Messages are all Recordbatches!"));

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }
}