common_grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::DfRecordBatch;
25use datatypes::arrow;
26use datatypes::arrow::array::ArrayRef;
27use datatypes::arrow::buffer::Buffer;
28use datatypes::arrow::datatypes::{DataType, Schema as ArrowSchema, SchemaRef};
29use datatypes::arrow::error::ArrowError;
30use datatypes::arrow::ipc::{convert, reader, root_as_message, writer, MessageHeader};
31use flatbuffers::FlatBufferBuilder;
32use prost::bytes::Bytes as ProstBytes;
33use prost::Message;
34use snafu::{OptionExt, ResultExt};
35use vec1::{vec1, Vec1};
36
37use crate::error;
38use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
39
40#[derive(Debug, Clone)]
41pub enum FlightMessage {
42    Schema(SchemaRef),
43    RecordBatch(DfRecordBatch),
44    AffectedRows(usize),
45    Metrics(String),
46}
47
48pub struct FlightEncoder {
49    write_options: writer::IpcWriteOptions,
50    data_gen: writer::IpcDataGenerator,
51    dictionary_tracker: writer::DictionaryTracker,
52}
53
54impl Default for FlightEncoder {
55    fn default() -> Self {
56        let write_options = writer::IpcWriteOptions::default()
57            .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
58            .unwrap();
59
60        Self {
61            write_options,
62            data_gen: writer::IpcDataGenerator::default(),
63            dictionary_tracker: writer::DictionaryTracker::new(false),
64        }
65    }
66}
67
68impl FlightEncoder {
69    /// Creates new [FlightEncoder] with compression disabled.
70    pub fn with_compression_disabled() -> Self {
71        let write_options = writer::IpcWriteOptions::default()
72            .try_with_compression(None)
73            .unwrap();
74
75        Self {
76            write_options,
77            data_gen: writer::IpcDataGenerator::default(),
78            dictionary_tracker: writer::DictionaryTracker::new(false),
79        }
80    }
81
82    /// Encode the Arrow schema to [FlightData].
83    pub fn encode_schema(&self, schema: &ArrowSchema) -> FlightData {
84        SchemaAsIpc::new(schema, &self.write_options).into()
85    }
86
87    /// Encode the [FlightMessage] to a list (at least one element) of [FlightData]s.
88    ///
89    /// Normally only when the [FlightMessage] is an Arrow [RecordBatch] with dictionary arrays
90    /// will the encoder produce more than one [FlightData]s. Other types of [FlightMessage] should
91    /// be encoded to exactly one [FlightData].
92    pub fn encode(&mut self, flight_message: FlightMessage) -> Vec1<FlightData> {
93        match flight_message {
94            FlightMessage::Schema(schema) => {
95                schema.fields().iter().for_each(|x| {
96                    if matches!(x.data_type(), DataType::Dictionary(_, _)) {
97                        self.dictionary_tracker.next_dict_id();
98                    }
99                });
100
101                vec1![self.encode_schema(schema.as_ref())]
102            }
103            FlightMessage::RecordBatch(record_batch) => {
104                let (encoded_dictionaries, encoded_batch) = self
105                    .data_gen
106                    .encoded_batch(
107                        &record_batch,
108                        &mut self.dictionary_tracker,
109                        &self.write_options,
110                    )
111                    .expect("DictionaryTracker configured above to not fail on replacement");
112
113                Vec1::from_vec_push(
114                    encoded_dictionaries.into_iter().map(Into::into).collect(),
115                    encoded_batch.into(),
116                )
117            }
118            FlightMessage::AffectedRows(rows) => {
119                let metadata = FlightMetadata {
120                    affected_rows: Some(AffectedRows { value: rows as _ }),
121                    metrics: None,
122                }
123                .encode_to_vec();
124                vec1![FlightData {
125                    flight_descriptor: None,
126                    data_header: build_none_flight_msg().into(),
127                    app_metadata: metadata.into(),
128                    data_body: ProstBytes::default(),
129                }]
130            }
131            FlightMessage::Metrics(s) => {
132                let metadata = FlightMetadata {
133                    affected_rows: None,
134                    metrics: Some(Metrics {
135                        metrics: s.as_bytes().to_vec(),
136                    }),
137                }
138                .encode_to_vec();
139                vec1![FlightData {
140                    flight_descriptor: None,
141                    data_header: build_none_flight_msg().into(),
142                    app_metadata: metadata.into(),
143                    data_body: ProstBytes::default(),
144                }]
145            }
146        }
147    }
148}
149
150#[derive(Default)]
151pub struct FlightDecoder {
152    schema: Option<SchemaRef>,
153    schema_bytes: Option<bytes::Bytes>,
154    dictionaries_by_id: HashMap<i64, ArrayRef>,
155}
156
157impl FlightDecoder {
158    /// Build a [FlightDecoder] instance from provided schema bytes.
159    pub fn try_from_schema_bytes(schema_bytes: &bytes::Bytes) -> Result<Self> {
160        let arrow_schema = convert::try_schema_from_flatbuffer_bytes(&schema_bytes[..])
161            .context(error::ArrowSnafu)?;
162        Ok(Self {
163            schema: Some(Arc::new(arrow_schema)),
164            schema_bytes: Some(schema_bytes.clone()),
165            dictionaries_by_id: HashMap::new(),
166        })
167    }
168
169    pub fn try_decode_record_batch(
170        &mut self,
171        data_header: &bytes::Bytes,
172        data_body: &bytes::Bytes,
173    ) -> Result<DfRecordBatch> {
174        let schema = self
175            .schema
176            .as_ref()
177            .context(InvalidFlightDataSnafu {
178                reason: "Should have decoded schema first!",
179            })?
180            .clone();
181        let message = root_as_message(&data_header[..])
182            .map_err(|err| {
183                ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
184            })
185            .context(error::ArrowSnafu)?;
186        let result = message
187            .header_as_record_batch()
188            .ok_or_else(|| {
189                ArrowError::ParseError(
190                    "Unable to convert flight data header to a record batch".to_string(),
191                )
192            })
193            .and_then(|batch| {
194                reader::read_record_batch(
195                    &Buffer::from(data_body.as_ref()),
196                    batch,
197                    schema,
198                    &HashMap::new(),
199                    None,
200                    &message.version(),
201                )
202            })
203            .context(error::ArrowSnafu)?;
204        Ok(result)
205    }
206
207    /// Try to decode the [FlightData] to a [FlightMessage].
208    ///
209    /// If the [FlightData] is of type `DictionaryBatch` (produced while encoding an Arrow
210    /// [RecordBatch] with dictionary arrays), the decoder will not return any [FlightMessage]s.
211    /// Instead, it will update its internal dictionary cache. Other types of [FlightData] will
212    /// be decoded to exactly one [FlightMessage].
213    pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<Option<FlightMessage>> {
214        let message = root_as_message(&flight_data.data_header).map_err(|e| {
215            InvalidFlightDataSnafu {
216                reason: e.to_string(),
217            }
218            .build()
219        })?;
220        match message.header_type() {
221            MessageHeader::NONE => {
222                let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
223                    .context(DecodeFlightDataSnafu)?;
224                if let Some(AffectedRows { value }) = metadata.affected_rows {
225                    return Ok(Some(FlightMessage::AffectedRows(value as _)));
226                }
227                if let Some(Metrics { metrics }) = metadata.metrics {
228                    return Ok(Some(FlightMessage::Metrics(
229                        String::from_utf8_lossy(&metrics).to_string(),
230                    )));
231                }
232                InvalidFlightDataSnafu {
233                    reason: "Expecting FlightMetadata have some meaningful content.",
234                }
235                .fail()
236            }
237            MessageHeader::Schema => {
238                let arrow_schema = Arc::new(ArrowSchema::try_from(flight_data).map_err(|e| {
239                    InvalidFlightDataSnafu {
240                        reason: e.to_string(),
241                    }
242                    .build()
243                })?);
244                self.schema = Some(arrow_schema.clone());
245                self.schema_bytes = Some(flight_data.data_header.clone());
246                Ok(Some(FlightMessage::Schema(arrow_schema)))
247            }
248            MessageHeader::RecordBatch => {
249                let schema = self.schema.clone().context(InvalidFlightDataSnafu {
250                    reason: "Should have decoded schema first!",
251                })?;
252                let arrow_batch = flight_data_to_arrow_batch(
253                    flight_data,
254                    schema.clone(),
255                    &self.dictionaries_by_id,
256                )
257                .map_err(|e| {
258                    InvalidFlightDataSnafu {
259                        reason: e.to_string(),
260                    }
261                    .build()
262                })?;
263                Ok(Some(FlightMessage::RecordBatch(arrow_batch)))
264            }
265            MessageHeader::DictionaryBatch => {
266                let dictionary_batch =
267                    message
268                        .header_as_dictionary_batch()
269                        .context(InvalidFlightDataSnafu {
270                            reason: "could not get dictionary batch from DictionaryBatch message",
271                        })?;
272
273                let schema = self.schema.as_ref().context(InvalidFlightDataSnafu {
274                    reason: "schema message is not present previously",
275                })?;
276
277                reader::read_dictionary(
278                    &flight_data.data_body.clone().into(),
279                    dictionary_batch,
280                    schema,
281                    &mut self.dictionaries_by_id,
282                    &message.version(),
283                )
284                .context(error::ArrowSnafu)?;
285                Ok(None)
286            }
287            other => {
288                let name = other.variant_name().unwrap_or("UNKNOWN");
289                InvalidFlightDataSnafu {
290                    reason: format!("Unsupported FlightData type: {name}"),
291                }
292                .fail()
293            }
294        }
295    }
296
297    pub fn schema(&self) -> Option<&SchemaRef> {
298        self.schema.as_ref()
299    }
300
301    pub fn schema_bytes(&self) -> Option<bytes::Bytes> {
302        self.schema_bytes.clone()
303    }
304}
305
306pub fn flight_messages_to_recordbatches(
307    messages: Vec<FlightMessage>,
308) -> Result<Vec<DfRecordBatch>> {
309    if messages.is_empty() {
310        Ok(vec![])
311    } else {
312        let mut recordbatches = Vec::with_capacity(messages.len() - 1);
313
314        match &messages[0] {
315            FlightMessage::Schema(_schema) => {}
316            _ => {
317                return InvalidFlightDataSnafu {
318                    reason: "First Flight Message must be schema!",
319                }
320                .fail()
321            }
322        };
323
324        for message in messages.into_iter().skip(1) {
325            match message {
326                FlightMessage::RecordBatch(recordbatch) => recordbatches.push(recordbatch),
327                _ => {
328                    return InvalidFlightDataSnafu {
329                        reason: "Expect the following Flight Messages are all Recordbatches!",
330                    }
331                    .fail()
332                }
333            }
334        }
335
336        Ok(recordbatches)
337    }
338}
339
340fn build_none_flight_msg() -> Bytes {
341    let mut builder = FlatBufferBuilder::new();
342
343    let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
344    message.add_version(arrow::ipc::MetadataVersion::V5);
345    message.add_header_type(MessageHeader::NONE);
346    message.add_bodyLength(0);
347
348    let data = message.finish();
349    builder.finish(data, None);
350
351    builder.finished_data().into()
352}
353
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::array::{
        DictionaryArray, Int32Array, StringArray, UInt32Array, UInt8Array,
    };
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::Error;

    /// Decoding must fail before a schema was seen, and afterwards return the schema and the
    /// record batches unchanged.
    #[test]
    fn test_try_decode() -> Result<()> {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));
        let int_batch = |values: Vec<Option<i32>>| {
            DfRecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from(values)) as _],
            )
            .unwrap()
        };

        let first = int_batch(vec![Some(1), None, Some(3)]);
        let second = int_batch(vec![None, Some(5)]);

        let frames =
            batches_to_flight_data(&schema, vec![first.clone(), second.clone()]).unwrap();
        assert_eq!(frames.len(), 3);
        let [schema_frame, first_frame, second_frame] = frames.as_slice() else {
            unreachable!()
        };

        let mut decoder = FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // A record batch that arrives before any schema must be rejected.
        let premature = decoder.try_decode(first_frame);
        assert!(matches!(premature, Err(Error::InvalidFlightData { .. })));
        assert!(premature
            .unwrap_err()
            .to_string()
            .contains("Should have decoded schema first!"));

        let Some(FlightMessage::Schema(decoded_schema)) = decoder.try_decode(schema_frame)?
        else {
            unreachable!()
        };
        assert_eq!(decoded_schema, schema);
        assert!(decoder.schema.is_some());

        for (frame, expected) in [(first_frame, first), (second_frame, second)] {
            let Some(FlightMessage::RecordBatch(actual)) = decoder.try_decode(frame)? else {
                unreachable!()
            };
            assert_eq!(actual, expected);
        }
        Ok(())
    }

    /// The message list must start with a schema and continue with record batches only.
    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
        let int_batch = |values: Vec<Option<i32>>| {
            DfRecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from(values)) as _],
            )
            .unwrap()
        };
        let batch1 = int_batch(vec![Some(2), None, Some(4)]);
        let batch2 = int_batch(vec![None, Some(6)]);
        let expected = vec![batch1.clone(), batch2.clone()];

        let schema_msg = FlightMessage::Schema(schema.clone());
        let batch_msg1 = FlightMessage::RecordBatch(batch1);
        let batch_msg2 = FlightMessage::RecordBatch(batch2);

        let assert_invalid = |messages: Vec<FlightMessage>, reason: &str| {
            let result = flight_messages_to_recordbatches(messages);
            assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
            assert!(result.unwrap_err().to_string().contains(reason));
        };

        assert_invalid(
            vec![batch_msg1.clone(), schema_msg.clone(), batch_msg2.clone()],
            "First Flight Message must be schema!",
        );
        assert_invalid(
            vec![schema_msg.clone(), batch_msg1.clone(), schema_msg.clone()],
            "Expect the following Flight Messages are all Recordbatches!",
        );

        let actual =
            flight_messages_to_recordbatches(vec![schema_msg, batch_msg1, batch_msg2]).unwrap();
        assert_eq!(actual, expected);
    }

    /// Record batches with dictionary arrays must survive an encode/decode round trip.
    #[test]
    fn test_flight_encode_decode_with_dictionary_array() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::UInt8, true),
            Field::new_dictionary("s", DataType::UInt32, DataType::Utf8, true),
        ]));
        let dictionary_batch = |numbers: Vec<u8>, keys: UInt32Array, words: Vec<&str>| {
            DfRecordBatch::try_new(
                schema.clone(),
                vec![
                    Arc::new(UInt8Array::from_iter_values(numbers)) as _,
                    Arc::new(DictionaryArray::new(
                        keys,
                        Arc::new(StringArray::from_iter_values(words)),
                    )) as _,
                ],
            )
            .unwrap()
        };
        let batch1 = dictionary_batch(vec![1, 2, 3], UInt32Array::from_value(0, 3), vec!["x"]);
        let batch2 = dictionary_batch(
            vec![4, 5, 6, 7, 8],
            UInt32Array::from_iter_values([0, 1, 2, 2, 3]),
            vec!["h", "e", "l", "o"],
        );

        let mut encoder = FlightEncoder::default();
        let schema_frames = encoder.encode(FlightMessage::Schema(schema.clone()));
        let batch1_frames = encoder.encode(FlightMessage::RecordBatch(batch1));
        let batch2_frames = encoder.encode(FlightMessage::RecordBatch(batch2));
        // The Arrow schema should be encoded to one FlightData:
        assert_eq!(schema_frames.len(), 1);
        // Both Arrow RecordBatches carry dictionary arrays, so each should be encoded to
        // multiple FlightData:
        assert_eq!(batch1_frames.len(), 2);
        assert_eq!(batch2_frames.len(), 2);

        let mut decoder = FlightDecoder::default();
        let Some(FlightMessage::Schema(actual_schema)) =
            decoder.try_decode(schema_frames.first())?
        else {
            unreachable!()
        };
        assert_eq!(actual_schema, schema);

        let mut decoded_batches = Vec::with_capacity(2);
        for frames in [&batch1_frames, &batch2_frames] {
            // The first frame is a dictionary batch message, decoder should return none:
            assert!(decoder.try_decode(&frames[0])?.is_none());
            let Some(FlightMessage::RecordBatch(batch)) = decoder.try_decode(&frames[1])? else {
                unreachable!()
            };
            decoded_batches.push(batch);
        }

        let actual = arrow::util::pretty::pretty_format_batches(&decoded_batches)
            .unwrap()
            .to_string();
        let expected = r"
+---+---+
| i | s |
+---+---+
| 1 | x |
| 2 | x |
| 3 | x |
| 4 | h |
| 5 | e |
| 6 | l |
| 7 | l |
| 8 | o |
+---+---+";
        assert_eq!(actual, expected.trim());
        Ok(())
    }
}
544}