common_grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::DfRecordBatch;
25use datatypes::arrow;
26use datatypes::arrow::array::ArrayRef;
27use datatypes::arrow::buffer::Buffer;
28use datatypes::arrow::datatypes::{DataType, Schema as ArrowSchema, SchemaRef};
29use datatypes::arrow::error::ArrowError;
30use datatypes::arrow::ipc::{MessageHeader, convert, reader, root_as_message, writer};
31use flatbuffers::FlatBufferBuilder;
32use prost::Message;
33use prost::bytes::Bytes as ProstBytes;
34use snafu::{OptionExt, ResultExt};
35use vec1::{Vec1, vec1};
36
37use crate::error;
38use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
39
/// A high-level message exchanged over an Arrow Flight stream.
///
/// [FlightEncoder] turns these into one or more [FlightData] frames and [FlightDecoder] turns
/// [FlightData] frames back into them.
#[derive(Debug, Clone)]
pub enum FlightMessage {
    /// An Arrow schema.
    Schema(SchemaRef),
    /// An Arrow record batch.
    RecordBatch(DfRecordBatch),
    /// A row count, carried in the `affected_rows` field of [FlightMetadata].
    AffectedRows(usize),
    /// A metrics string, carried as bytes in the `metrics` field of [FlightMetadata].
    Metrics(String),
}
47
/// Encodes [FlightMessage]s into Arrow Flight [FlightData] frames.
///
/// The encoder is stateful: the dictionary tracker is updated while encoding schemas and
/// record batches, so one encoder instance should be used for one logical stream.
pub struct FlightEncoder {
    /// IPC write options (including the compression setting) applied to every encoded frame.
    write_options: writer::IpcWriteOptions,
    /// Arrow IPC generator used to serialize record batches and their dictionaries.
    data_gen: writer::IpcDataGenerator,
    /// Tracks the dictionaries seen so far while encoding.
    dictionary_tracker: writer::DictionaryTracker,
}
53
54impl Default for FlightEncoder {
55    fn default() -> Self {
56        let write_options = writer::IpcWriteOptions::default()
57            .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
58            .unwrap();
59
60        Self {
61            write_options,
62            data_gen: writer::IpcDataGenerator::default(),
63            dictionary_tracker: writer::DictionaryTracker::new(false),
64        }
65    }
66}
67
68impl FlightEncoder {
69    /// Creates new [FlightEncoder] with compression disabled.
70    pub fn with_compression_disabled() -> Self {
71        let write_options = writer::IpcWriteOptions::default()
72            .try_with_compression(None)
73            .unwrap();
74
75        Self {
76            write_options,
77            data_gen: writer::IpcDataGenerator::default(),
78            dictionary_tracker: writer::DictionaryTracker::new(false),
79        }
80    }
81
82    /// Encode the Arrow schema to [FlightData].
83    pub fn encode_schema(&self, schema: &ArrowSchema) -> FlightData {
84        SchemaAsIpc::new(schema, &self.write_options).into()
85    }
86
87    /// Encode the [FlightMessage] to a list (at least one element) of [FlightData]s.
88    ///
89    /// Normally only when the [FlightMessage] is an Arrow [RecordBatch] with dictionary arrays
90    /// will the encoder produce more than one [FlightData]s. Other types of [FlightMessage] should
91    /// be encoded to exactly one [FlightData].
92    pub fn encode(&mut self, flight_message: FlightMessage) -> Vec1<FlightData> {
93        match flight_message {
94            FlightMessage::Schema(schema) => {
95                schema.fields().iter().for_each(|x| {
96                    if matches!(x.data_type(), DataType::Dictionary(_, _)) {
97                        self.dictionary_tracker.next_dict_id();
98                    }
99                });
100
101                vec1![self.encode_schema(schema.as_ref())]
102            }
103            FlightMessage::RecordBatch(record_batch) => {
104                let (encoded_dictionaries, encoded_batch) = self
105                    .data_gen
106                    .encode(
107                        &record_batch,
108                        &mut self.dictionary_tracker,
109                        &self.write_options,
110                        &mut Default::default(),
111                    )
112                    .expect("DictionaryTracker configured above to not fail on replacement");
113
114                Vec1::from_vec_push(
115                    encoded_dictionaries.into_iter().map(Into::into).collect(),
116                    encoded_batch.into(),
117                )
118            }
119            FlightMessage::AffectedRows(rows) => {
120                let metadata = FlightMetadata {
121                    affected_rows: Some(AffectedRows { value: rows as _ }),
122                    metrics: None,
123                }
124                .encode_to_vec();
125                vec1![FlightData {
126                    flight_descriptor: None,
127                    data_header: build_none_flight_msg().into(),
128                    app_metadata: metadata.into(),
129                    data_body: ProstBytes::default(),
130                }]
131            }
132            FlightMessage::Metrics(s) => {
133                let metadata = FlightMetadata {
134                    affected_rows: None,
135                    metrics: Some(Metrics {
136                        metrics: s.as_bytes().to_vec(),
137                    }),
138                }
139                .encode_to_vec();
140                vec1![FlightData {
141                    flight_descriptor: None,
142                    data_header: build_none_flight_msg().into(),
143                    app_metadata: metadata.into(),
144                    data_body: ProstBytes::default(),
145                }]
146            }
147        }
148    }
149}
150
/// Decodes Arrow Flight [FlightData] frames back into [FlightMessage]s.
///
/// The decoder is stateful: it remembers the last decoded schema and the dictionaries received
/// so far, both of which are needed to decode subsequent record batches.
#[derive(Default)]
pub struct FlightDecoder {
    /// The schema used to decode record batches; `None` until a schema has been provided.
    schema: Option<SchemaRef>,
    /// The raw bytes the current schema was decoded from, if any.
    schema_bytes: Option<bytes::Bytes>,
    /// Dictionaries received so far, keyed by dictionary id.
    dictionaries_by_id: HashMap<i64, ArrayRef>,
}
157
158impl FlightDecoder {
159    /// Build a [FlightDecoder] instance from provided schema bytes.
160    pub fn try_from_schema_bytes(schema_bytes: &bytes::Bytes) -> Result<Self> {
161        let arrow_schema = convert::try_schema_from_flatbuffer_bytes(&schema_bytes[..])
162            .context(error::ArrowSnafu)?;
163        Ok(Self {
164            schema: Some(Arc::new(arrow_schema)),
165            schema_bytes: Some(schema_bytes.clone()),
166            dictionaries_by_id: HashMap::new(),
167        })
168    }
169
170    pub fn try_decode_record_batch(
171        &mut self,
172        data_header: &bytes::Bytes,
173        data_body: &bytes::Bytes,
174    ) -> Result<DfRecordBatch> {
175        let schema = self
176            .schema
177            .as_ref()
178            .context(InvalidFlightDataSnafu {
179                reason: "Should have decoded schema first!",
180            })?
181            .clone();
182        let message = root_as_message(&data_header[..])
183            .map_err(|err| {
184                ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
185            })
186            .context(error::ArrowSnafu)?;
187        let result = message
188            .header_as_record_batch()
189            .ok_or_else(|| {
190                ArrowError::ParseError(
191                    "Unable to convert flight data header to a record batch".to_string(),
192                )
193            })
194            .and_then(|batch| {
195                reader::read_record_batch(
196                    &Buffer::from(data_body.as_ref()),
197                    batch,
198                    schema,
199                    &HashMap::new(),
200                    None,
201                    &message.version(),
202                )
203            })
204            .context(error::ArrowSnafu)?;
205        Ok(result)
206    }
207
208    /// Try to decode the [FlightData] to a [FlightMessage].
209    ///
210    /// If the [FlightData] is of type `DictionaryBatch` (produced while encoding an Arrow
211    /// [RecordBatch] with dictionary arrays), the decoder will not return any [FlightMessage]s.
212    /// Instead, it will update its internal dictionary cache. Other types of [FlightData] will
213    /// be decoded to exactly one [FlightMessage].
214    pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<Option<FlightMessage>> {
215        let message = root_as_message(&flight_data.data_header).map_err(|e| {
216            InvalidFlightDataSnafu {
217                reason: e.to_string(),
218            }
219            .build()
220        })?;
221        match message.header_type() {
222            MessageHeader::NONE => {
223                let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
224                    .context(DecodeFlightDataSnafu)?;
225                if let Some(AffectedRows { value }) = metadata.affected_rows {
226                    return Ok(Some(FlightMessage::AffectedRows(value as _)));
227                }
228                if let Some(Metrics { metrics }) = metadata.metrics {
229                    return Ok(Some(FlightMessage::Metrics(
230                        String::from_utf8_lossy(&metrics).to_string(),
231                    )));
232                }
233                InvalidFlightDataSnafu {
234                    reason: "Expecting FlightMetadata have some meaningful content.",
235                }
236                .fail()
237            }
238            MessageHeader::Schema => {
239                let arrow_schema = Arc::new(ArrowSchema::try_from(flight_data).map_err(|e| {
240                    InvalidFlightDataSnafu {
241                        reason: e.to_string(),
242                    }
243                    .build()
244                })?);
245                self.schema = Some(arrow_schema.clone());
246                self.schema_bytes = Some(flight_data.data_header.clone());
247                Ok(Some(FlightMessage::Schema(arrow_schema)))
248            }
249            MessageHeader::RecordBatch => {
250                let schema = self.schema.clone().context(InvalidFlightDataSnafu {
251                    reason: "Should have decoded schema first!",
252                })?;
253                let arrow_batch = flight_data_to_arrow_batch(
254                    flight_data,
255                    schema.clone(),
256                    &self.dictionaries_by_id,
257                )
258                .map_err(|e| {
259                    InvalidFlightDataSnafu {
260                        reason: e.to_string(),
261                    }
262                    .build()
263                })?;
264                Ok(Some(FlightMessage::RecordBatch(arrow_batch)))
265            }
266            MessageHeader::DictionaryBatch => {
267                let dictionary_batch =
268                    message
269                        .header_as_dictionary_batch()
270                        .context(InvalidFlightDataSnafu {
271                            reason: "could not get dictionary batch from DictionaryBatch message",
272                        })?;
273
274                let schema = self.schema.as_ref().context(InvalidFlightDataSnafu {
275                    reason: "schema message is not present previously",
276                })?;
277
278                reader::read_dictionary(
279                    &flight_data.data_body.clone().into(),
280                    dictionary_batch,
281                    schema,
282                    &mut self.dictionaries_by_id,
283                    &message.version(),
284                )
285                .context(error::ArrowSnafu)?;
286                Ok(None)
287            }
288            other => {
289                let name = other.variant_name().unwrap_or("UNKNOWN");
290                InvalidFlightDataSnafu {
291                    reason: format!("Unsupported FlightData type: {name}"),
292                }
293                .fail()
294            }
295        }
296    }
297
298    pub fn schema(&self) -> Option<&SchemaRef> {
299        self.schema.as_ref()
300    }
301
302    pub fn schema_bytes(&self) -> Option<bytes::Bytes> {
303        self.schema_bytes.clone()
304    }
305}
306
307pub fn flight_messages_to_recordbatches(
308    messages: Vec<FlightMessage>,
309) -> Result<Vec<DfRecordBatch>> {
310    if messages.is_empty() {
311        Ok(vec![])
312    } else {
313        let mut recordbatches = Vec::with_capacity(messages.len() - 1);
314
315        match &messages[0] {
316            FlightMessage::Schema(_schema) => {}
317            _ => {
318                return InvalidFlightDataSnafu {
319                    reason: "First Flight Message must be schema!",
320                }
321                .fail();
322            }
323        };
324
325        for message in messages.into_iter().skip(1) {
326            match message {
327                FlightMessage::RecordBatch(recordbatch) => recordbatches.push(recordbatch),
328                _ => {
329                    return InvalidFlightDataSnafu {
330                        reason: "Expect the following Flight Messages are all Recordbatches!",
331                    }
332                    .fail();
333                }
334            }
335        }
336
337        Ok(recordbatches)
338    }
339}
340
341fn build_none_flight_msg() -> Bytes {
342    let mut builder = FlatBufferBuilder::new();
343
344    let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
345    message.add_version(arrow::ipc::MetadataVersion::V5);
346    message.add_header_type(MessageHeader::NONE);
347    message.add_bodyLength(0);
348
349    let data = message.finish();
350    builder.finish(data, None);
351
352    builder.finished_data().into()
353}
354
// Unit tests for the Flight encoder/decoder and the message-to-recordbatch conversion.
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::array::{
        DictionaryArray, Int32Array, StringArray, UInt8Array, UInt32Array,
    };
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::Error;

    // Decodes a schema followed by two record batches, and checks that decoding a record
    // batch before the schema is rejected.
    #[test]
    fn test_try_decode() -> Result<()> {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));

        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(5)])) as _],
        )
        .unwrap();

        // One frame for the schema plus one frame per record batch.
        let flight_data =
            batches_to_flight_data(&schema, vec![batch1.clone(), batch2.clone()]).unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let decoder = &mut FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // A record batch cannot be decoded before the schema has been seen.
        let result = decoder.try_decode(d2);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Should have decoded schema first!")
        );

        let message = decoder.try_decode(d1)?.unwrap();
        assert!(matches!(message, FlightMessage::Schema(_)));
        let FlightMessage::Schema(decoded_schema) = message else {
            unreachable!()
        };
        assert_eq!(decoded_schema, schema);

        // The decoder must have stored the schema after decoding it.
        let _ = decoder.schema.as_ref().unwrap();

        let message = decoder.try_decode(d2)?.unwrap();
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(actual_batch) = message else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch1);

        let message = decoder.try_decode(d3)?.unwrap();
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(actual_batch) = message else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch2);
        Ok(())
    }

    // Checks the ordering rules enforced by `flight_messages_to_recordbatches`: schema first,
    // then only record batches.
    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(6)])) as _],
        )
        .unwrap();
        let recordbatches = vec![batch1.clone(), batch2.clone()];

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::RecordBatch(batch1);
        let m3 = FlightMessage::RecordBatch(batch2);

        // The first message is a record batch instead of a schema.
        let result = flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("First Flight Message must be schema!")
        );

        // A schema appears again after the leading schema.
        let result = flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Expect the following Flight Messages are all Recordbatches!")
        );

        // Well-formed sequence: schema followed by record batches only.
        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }

    // Round-trips record batches containing a dictionary array through the encoder and the
    // decoder, using a different dictionary for each batch.
    #[test]
    fn test_flight_encode_decode_with_dictionary_array() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::UInt8, true),
            Field::new_dictionary("s", DataType::UInt32, DataType::Utf8, true),
        ]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![1, 2, 3])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_value(0, 3),
                    Arc::new(StringArray::from_iter_values(["x"])),
                )) as _,
            ],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![4, 5, 6, 7, 8])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_iter_values([0, 1, 2, 2, 3]),
                    Arc::new(StringArray::from_iter_values(["h", "e", "l", "o"])),
                )) as _,
            ],
        )
        .unwrap();

        let message_1 = FlightMessage::Schema(schema.clone());
        let message_2 = FlightMessage::RecordBatch(batch1);
        let message_3 = FlightMessage::RecordBatch(batch2);

        let mut encoder = FlightEncoder::default();
        let encoded_1 = encoder.encode(message_1);
        let encoded_2 = encoder.encode(message_2);
        let encoded_3 = encoder.encode(message_3);
        // message 1 is Arrow Schema, should be encoded to one FlightData:
        assert_eq!(encoded_1.len(), 1);
        // message 2 and 3 are Arrow RecordBatch with dictionary arrays, should be encoded to
        // multiple FlightData:
        assert_eq!(encoded_2.len(), 2);
        assert_eq!(encoded_3.len(), 2);

        let mut decoder = FlightDecoder::default();
        let decoded_1 = decoder.try_decode(encoded_1.first())?;
        let Some(FlightMessage::Schema(actual_schema)) = decoded_1 else {
            unreachable!()
        };
        assert_eq!(actual_schema, schema);
        let decoded_2 = decoder.try_decode(&encoded_2[0])?;
        // expected to be a dictionary batch message, decoder should return none:
        assert!(decoded_2.is_none());
        let Some(FlightMessage::RecordBatch(decoded_2)) = decoder.try_decode(&encoded_2[1])? else {
            unreachable!()
        };
        let decoded_3 = decoder.try_decode(&encoded_3[0])?;
        // expected to be a dictionary batch message, decoder should return none:
        assert!(decoded_3.is_none());
        let Some(FlightMessage::RecordBatch(decoded_3)) = decoder.try_decode(&encoded_3[1])? else {
            unreachable!()
        };
        // Compare the pretty-printed output so dictionary values are checked after resolution.
        let actual = arrow::util::pretty::pretty_format_batches(&[decoded_2, decoded_3])
            .unwrap()
            .to_string();
        let expected = r"
+---+---+
| i | s |
+---+---+
| 1 | x |
| 2 | x |
| 3 | x |
| 4 | h |
| 5 | e |
| 6 | l |
| 7 | l |
| 8 | o |
+---+---+";
        assert_eq!(actual, expected.trim());
        Ok(())
    }
}