// common_grpc/flight.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::DfRecordBatch;
25use datatypes::arrow;
26use datatypes::arrow::array::ArrayRef;
27use datatypes::arrow::buffer::Buffer;
28use datatypes::arrow::datatypes::{DataType, Schema as ArrowSchema, SchemaRef};
29use datatypes::arrow::error::ArrowError;
30use datatypes::arrow::ipc::{MessageHeader, convert, reader, root_as_message, writer};
31use flatbuffers::FlatBufferBuilder;
32use prost::Message;
33use prost::bytes::Bytes as ProstBytes;
34use snafu::{OptionExt, ResultExt};
35use vec1::{Vec1, vec1};
36
37use crate::error;
38use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
39
/// Name of the Flight metadata entry that transports flow query extensions,
/// serialized as JSON key/value pairs.
pub const FLOW_EXTENSIONS_METADATA_KEY: &str = "x-greptime-flow-extensions";

/// Name of the Flight metadata entry that transports the query's snapshot
/// read upper bounds, serialized as JSON.
pub const SNAPSHOT_SEQS_METADATA_KEY: &str = "x-greptime-snapshot-seqs";
44
45#[derive(Debug, Clone)]
46pub enum FlightMessage {
47    Schema(SchemaRef),
48    RecordBatch(DfRecordBatch),
49    AffectedRows {
50        rows: usize,
51        metrics: Option<String>,
52    },
53    Metrics(String),
54}
55
56pub struct FlightEncoder {
57    write_options: writer::IpcWriteOptions,
58    data_gen: writer::IpcDataGenerator,
59    dictionary_tracker: writer::DictionaryTracker,
60}
61
62impl Default for FlightEncoder {
63    fn default() -> Self {
64        let write_options = writer::IpcWriteOptions::default()
65            .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
66            .unwrap();
67
68        Self {
69            write_options,
70            data_gen: writer::IpcDataGenerator::default(),
71            dictionary_tracker: writer::DictionaryTracker::new(false),
72        }
73    }
74}
75
76impl FlightEncoder {
77    /// Creates new [FlightEncoder] with compression disabled.
78    pub fn with_compression_disabled() -> Self {
79        let write_options = writer::IpcWriteOptions::default()
80            .try_with_compression(None)
81            .unwrap();
82
83        Self {
84            write_options,
85            data_gen: writer::IpcDataGenerator::default(),
86            dictionary_tracker: writer::DictionaryTracker::new(false),
87        }
88    }
89
90    /// Encode the Arrow schema to [FlightData].
91    pub fn encode_schema(&self, schema: &ArrowSchema) -> FlightData {
92        SchemaAsIpc::new(schema, &self.write_options).into()
93    }
94
95    /// Encode the [FlightMessage] to a list (at least one element) of [FlightData]s.
96    ///
97    /// Normally only when the [FlightMessage] is an Arrow [RecordBatch] with dictionary arrays
98    /// will the encoder produce more than one [FlightData]s. Other types of [FlightMessage] should
99    /// be encoded to exactly one [FlightData].
100    pub fn encode(&mut self, flight_message: FlightMessage) -> Vec1<FlightData> {
101        match flight_message {
102            FlightMessage::Schema(schema) => {
103                schema.fields().iter().for_each(|x| {
104                    if matches!(x.data_type(), DataType::Dictionary(_, _)) {
105                        self.dictionary_tracker.next_dict_id();
106                    }
107                });
108
109                vec1![self.encode_schema(schema.as_ref())]
110            }
111            FlightMessage::RecordBatch(record_batch) => {
112                let (encoded_dictionaries, encoded_batch) = self
113                    .data_gen
114                    .encode(
115                        &record_batch,
116                        &mut self.dictionary_tracker,
117                        &self.write_options,
118                        &mut Default::default(),
119                    )
120                    .expect("DictionaryTracker configured above to not fail on replacement");
121
122                Vec1::from_vec_push(
123                    encoded_dictionaries.into_iter().map(Into::into).collect(),
124                    encoded_batch.into(),
125                )
126            }
127            FlightMessage::AffectedRows { rows, metrics } => {
128                let metadata = FlightMetadata {
129                    affected_rows: Some(AffectedRows { value: rows as _ }),
130                    metrics: metrics.map(|s| Metrics {
131                        metrics: s.into_bytes(),
132                    }),
133                }
134                .encode_to_vec();
135                vec1![FlightData {
136                    flight_descriptor: None,
137                    data_header: build_none_flight_msg().into(),
138                    app_metadata: metadata.into(),
139                    data_body: ProstBytes::default(),
140                }]
141            }
142            FlightMessage::Metrics(s) => {
143                let metadata = FlightMetadata {
144                    affected_rows: None,
145                    metrics: Some(Metrics {
146                        metrics: s.as_bytes().to_vec(),
147                    }),
148                }
149                .encode_to_vec();
150                vec1![FlightData {
151                    flight_descriptor: None,
152                    data_header: build_none_flight_msg().into(),
153                    app_metadata: metadata.into(),
154                    data_body: ProstBytes::default(),
155                }]
156            }
157        }
158    }
159}
160
161#[derive(Default)]
162pub struct FlightDecoder {
163    schema: Option<SchemaRef>,
164    schema_bytes: Option<bytes::Bytes>,
165    dictionaries_by_id: HashMap<i64, ArrayRef>,
166}
167
168impl FlightDecoder {
169    /// Build a [FlightDecoder] instance from provided schema bytes.
170    pub fn try_from_schema_bytes(schema_bytes: &bytes::Bytes) -> Result<Self> {
171        let arrow_schema = convert::try_schema_from_flatbuffer_bytes(&schema_bytes[..])
172            .context(error::ArrowSnafu)?;
173        Ok(Self {
174            schema: Some(Arc::new(arrow_schema)),
175            schema_bytes: Some(schema_bytes.clone()),
176            dictionaries_by_id: HashMap::new(),
177        })
178    }
179
180    pub fn try_decode_record_batch(
181        &mut self,
182        data_header: &bytes::Bytes,
183        data_body: &bytes::Bytes,
184    ) -> Result<DfRecordBatch> {
185        let schema = self
186            .schema
187            .as_ref()
188            .context(InvalidFlightDataSnafu {
189                reason: "Should have decoded schema first!",
190            })?
191            .clone();
192        let message = root_as_message(&data_header[..])
193            .map_err(|err| {
194                ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
195            })
196            .context(error::ArrowSnafu)?;
197        let result = message
198            .header_as_record_batch()
199            .ok_or_else(|| {
200                ArrowError::ParseError(
201                    "Unable to convert flight data header to a record batch".to_string(),
202                )
203            })
204            .and_then(|batch| {
205                reader::read_record_batch(
206                    &Buffer::from(data_body.as_ref()),
207                    batch,
208                    schema,
209                    &HashMap::new(),
210                    None,
211                    &message.version(),
212                )
213            })
214            .context(error::ArrowSnafu)?;
215        Ok(result)
216    }
217
218    /// Try to decode the [FlightData] to a [FlightMessage].
219    ///
220    /// If the [FlightData] is of type `DictionaryBatch` (produced while encoding an Arrow
221    /// [RecordBatch] with dictionary arrays), the decoder will not return any [FlightMessage]s.
222    /// Instead, it will update its internal dictionary cache. Other types of [FlightData] will
223    /// be decoded to exactly one [FlightMessage].
224    pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<Option<FlightMessage>> {
225        let message = root_as_message(&flight_data.data_header).map_err(|e| {
226            InvalidFlightDataSnafu {
227                reason: e.to_string(),
228            }
229            .build()
230        })?;
231        match message.header_type() {
232            MessageHeader::NONE => {
233                let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
234                    .context(DecodeFlightDataSnafu)?;
235                if let Some(AffectedRows { value }) = metadata.affected_rows {
236                    return Ok(Some(FlightMessage::AffectedRows {
237                        rows: value as _,
238                        metrics: metadata
239                            .metrics
240                            .map(|m| String::from_utf8_lossy(&m.metrics).to_string()),
241                    }));
242                }
243                if let Some(Metrics { metrics }) = metadata.metrics {
244                    return Ok(Some(FlightMessage::Metrics(
245                        String::from_utf8_lossy(&metrics).to_string(),
246                    )));
247                }
248                InvalidFlightDataSnafu {
249                    reason: "Expecting FlightMetadata have some meaningful content.",
250                }
251                .fail()
252            }
253            MessageHeader::Schema => {
254                let arrow_schema = Arc::new(ArrowSchema::try_from(flight_data).map_err(|e| {
255                    InvalidFlightDataSnafu {
256                        reason: e.to_string(),
257                    }
258                    .build()
259                })?);
260                self.schema = Some(arrow_schema.clone());
261                self.schema_bytes = Some(flight_data.data_header.clone());
262                Ok(Some(FlightMessage::Schema(arrow_schema)))
263            }
264            MessageHeader::RecordBatch => {
265                let schema = self.schema.clone().context(InvalidFlightDataSnafu {
266                    reason: "Should have decoded schema first!",
267                })?;
268                let arrow_batch = flight_data_to_arrow_batch(
269                    flight_data,
270                    schema.clone(),
271                    &self.dictionaries_by_id,
272                )
273                .map_err(|e| {
274                    InvalidFlightDataSnafu {
275                        reason: e.to_string(),
276                    }
277                    .build()
278                })?;
279                Ok(Some(FlightMessage::RecordBatch(arrow_batch)))
280            }
281            MessageHeader::DictionaryBatch => {
282                let dictionary_batch =
283                    message
284                        .header_as_dictionary_batch()
285                        .context(InvalidFlightDataSnafu {
286                            reason: "could not get dictionary batch from DictionaryBatch message",
287                        })?;
288
289                let schema = self.schema.as_ref().context(InvalidFlightDataSnafu {
290                    reason: "schema message is not present previously",
291                })?;
292
293                reader::read_dictionary(
294                    &flight_data.data_body.clone().into(),
295                    dictionary_batch,
296                    schema,
297                    &mut self.dictionaries_by_id,
298                    &message.version(),
299                )
300                .context(error::ArrowSnafu)?;
301                Ok(None)
302            }
303            other => {
304                let name = other.variant_name().unwrap_or("UNKNOWN");
305                InvalidFlightDataSnafu {
306                    reason: format!("Unsupported FlightData type: {name}"),
307                }
308                .fail()
309            }
310        }
311    }
312
313    pub fn schema(&self) -> Option<&SchemaRef> {
314        self.schema.as_ref()
315    }
316
317    pub fn schema_bytes(&self) -> Option<bytes::Bytes> {
318        self.schema_bytes.clone()
319    }
320}
321
322pub fn flight_messages_to_recordbatches(
323    messages: Vec<FlightMessage>,
324) -> Result<Vec<DfRecordBatch>> {
325    if messages.is_empty() {
326        Ok(vec![])
327    } else {
328        let mut recordbatches = Vec::with_capacity(messages.len() - 1);
329
330        match &messages[0] {
331            FlightMessage::Schema(_schema) => {}
332            _ => {
333                return InvalidFlightDataSnafu {
334                    reason: "First Flight Message must be schema!",
335                }
336                .fail();
337            }
338        };
339
340        for message in messages.into_iter().skip(1) {
341            match message {
342                FlightMessage::RecordBatch(recordbatch) => recordbatches.push(recordbatch),
343                _ => {
344                    return InvalidFlightDataSnafu {
345                        reason: "Expect the following Flight Messages are all Recordbatches!",
346                    }
347                    .fail();
348                }
349            }
350        }
351
352        Ok(recordbatches)
353    }
354}
355
356fn build_none_flight_msg() -> Bytes {
357    let mut builder = FlatBufferBuilder::new();
358
359    let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
360    message.add_version(arrow::ipc::MetadataVersion::V5);
361    message.add_header_type(MessageHeader::NONE);
362    message.add_bodyLength(0);
363
364    let data = message.finish();
365    builder.finish(data, None);
366
367    builder.finished_data().into()
368}
369
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::array::{
        DictionaryArray, Int32Array, StringArray, UInt8Array, UInt32Array,
    };
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::Error;

    #[test]
    fn test_try_decode() -> Result<()> {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));

        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(5)])) as _],
        )
        .unwrap();

        let flight_data =
            batches_to_flight_data(&schema, vec![batch1.clone(), batch2.clone()]).unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let decoder = &mut FlightDecoder::default();
        assert!(decoder.schema.is_none());

        let result = decoder.try_decode(d2);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Should have decoded schema first!")
        );

        let message = decoder.try_decode(d1)?.unwrap();
        assert!(matches!(message, FlightMessage::Schema(_)));
        let FlightMessage::Schema(decoded_schema) = message else {
            unreachable!()
        };
        assert_eq!(decoded_schema, schema);

        // Decoding the schema message must also remember it (and its raw header bytes)
        // for the record batches that follow.
        assert_eq!(decoder.schema(), Some(&schema));
        assert!(decoder.schema_bytes().is_some());

        let message = decoder.try_decode(d2)?.unwrap();
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(actual_batch) = message else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch1);

        let message = decoder.try_decode(d3)?.unwrap();
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(actual_batch) = message else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch2);
        Ok(())
    }

    #[test]
    fn test_affected_rows_metrics_encode_decode() -> Result<()> {
        let metrics = r#"{"region_watermarks":[{"region_id":42,"watermark":7}]}"#;
        let mut encoder = FlightEncoder::default();
        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 3,
            metrics: Some(metrics.to_string()),
        });

        assert_eq!(encoded.len(), 1);

        let mut decoder = FlightDecoder::default();
        let decoded = decoder.try_decode(encoded.first())?.unwrap();
        let FlightMessage::AffectedRows {
            rows,
            metrics: decoded_metrics,
        } = decoded
        else {
            unreachable!()
        };
        assert_eq!(rows, 3);
        assert_eq!(decoded_metrics.as_deref(), Some(metrics));

        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 5,
            metrics: None,
        });
        let decoded = decoder.try_decode(encoded.first())?.unwrap();
        let FlightMessage::AffectedRows {
            rows,
            metrics: decoded_metrics,
        } = decoded
        else {
            unreachable!()
        };
        assert_eq!(rows, 5);
        assert!(decoded_metrics.is_none());

        Ok(())
    }

    #[test]
    fn test_metrics_encode_decode() -> Result<()> {
        let metrics = r#"{"region_watermarks":[{"region_id":7,"watermark":3}]}"#;
        let mut encoder = FlightEncoder::default();
        let encoded = encoder.encode(FlightMessage::Metrics(metrics.to_string()));
        // Metrics are metadata-only and must fit in a single FlightData.
        assert_eq!(encoded.len(), 1);

        let mut decoder = FlightDecoder::default();
        let Some(FlightMessage::Metrics(decoded_metrics)) = decoder.try_decode(encoded.first())?
        else {
            unreachable!()
        };
        assert_eq!(decoded_metrics, metrics);
        Ok(())
    }

    #[test]
    fn test_decode_empty_flight_metadata_is_invalid() {
        // A NONE-header frame carrying neither affected rows nor metrics is meaningless.
        let flight_data = FlightData {
            flight_descriptor: None,
            data_header: build_none_flight_msg().into(),
            app_metadata: FlightMetadata {
                affected_rows: None,
                metrics: None,
            }
            .encode_to_vec()
            .into(),
            data_body: ProstBytes::default(),
        };

        let result = FlightDecoder::default().try_decode(&flight_data);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Expecting FlightMetadata have some meaningful content.")
        );
    }

    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(6)])) as _],
        )
        .unwrap();
        let recordbatches = vec![batch1.clone(), batch2.clone()];

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::RecordBatch(batch1);
        let m3 = FlightMessage::RecordBatch(batch2);

        // Empty and schema-only streams carry no record batches.
        assert!(flight_messages_to_recordbatches(vec![]).unwrap().is_empty());
        assert!(
            flight_messages_to_recordbatches(vec![m1.clone()])
                .unwrap()
                .is_empty()
        );

        let result = flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("First Flight Message must be schema!")
        );

        let result = flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Expect the following Flight Messages are all Recordbatches!")
        );

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }

    #[test]
    fn test_flight_encode_decode_with_dictionary_array() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::UInt8, true),
            Field::new_dictionary("s", DataType::UInt32, DataType::Utf8, true),
        ]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![1, 2, 3])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_value(0, 3),
                    Arc::new(StringArray::from_iter_values(["x"])),
                )) as _,
            ],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![4, 5, 6, 7, 8])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_iter_values([0, 1, 2, 2, 3]),
                    Arc::new(StringArray::from_iter_values(["h", "e", "l", "o"])),
                )) as _,
            ],
        )
        .unwrap();

        let message_1 = FlightMessage::Schema(schema.clone());
        let message_2 = FlightMessage::RecordBatch(batch1);
        let message_3 = FlightMessage::RecordBatch(batch2);

        let mut encoder = FlightEncoder::default();
        let encoded_1 = encoder.encode(message_1);
        let encoded_2 = encoder.encode(message_2);
        let encoded_3 = encoder.encode(message_3);
        // message 1 is Arrow Schema, should be encoded to one FlightData:
        assert_eq!(encoded_1.len(), 1);
        // message 2 and 3 are Arrow RecordBatch with dictionary arrays, should be encoded to
        // multiple FlightData:
        assert_eq!(encoded_2.len(), 2);
        assert_eq!(encoded_3.len(), 2);

        let mut decoder = FlightDecoder::default();
        let decoded_1 = decoder.try_decode(encoded_1.first())?;
        let Some(FlightMessage::Schema(actual_schema)) = decoded_1 else {
            unreachable!()
        };
        assert_eq!(actual_schema, schema);
        let decoded_2 = decoder.try_decode(&encoded_2[0])?;
        // expected to be a dictionary batch message, decoder should return none:
        assert!(decoded_2.is_none());
        let Some(FlightMessage::RecordBatch(decoded_2)) = decoder.try_decode(&encoded_2[1])? else {
            unreachable!()
        };
        let decoded_3 = decoder.try_decode(&encoded_3[0])?;
        // expected to be a dictionary batch message, decoder should return none:
        assert!(decoded_3.is_none());
        let Some(FlightMessage::RecordBatch(decoded_3)) = decoder.try_decode(&encoded_3[1])? else {
            unreachable!()
        };
        let actual = arrow::util::pretty::pretty_format_batches(&[decoded_2, decoded_3])
            .unwrap()
            .to_string();
        let expected = r"
+---+---+
| i | s |
+---+---+
| 1 | x |
| 2 | x |
| 3 | x |
| 4 | h |
| 5 | e |
| 6 | l |
| 7 | l |
| 8 | o |
+---+---+";
        assert_eq!(actual, expected.trim());
        Ok(())
    }

    #[test]
    fn test_affected_rows_roundtrip_through_flight_codec() {
        // Verify the full FlightEncoder → FlightDecoder pipeline handles
        // the FlightMessage::AffectedRows variant with optional inline
        // metrics without breaking the wire protocol.
        let mut encoder = FlightEncoder::default();
        let mut decoder = FlightDecoder::default();

        // Without metrics — same wire format as old `AffectedRows(7)`.
        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 7,
            metrics: None,
        });
        let decoded = decoder.try_decode(encoded.first()).unwrap().unwrap();
        assert!(matches!(
            decoded,
            FlightMessage::AffectedRows {
                rows: 7,
                metrics: None,
            }
        ));

        // With metrics — row count and metrics payload are both preserved.
        let json = r#"{"region_watermarks":[{"region_id":1,"watermark":99}]}"#;
        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 42,
            metrics: Some(json.to_string()),
        });
        let decoded = decoder.try_decode(encoded.first()).unwrap().unwrap();
        let FlightMessage::AffectedRows { rows, metrics } = decoded else {
            unreachable!()
        };
        assert_eq!(rows, 42);
        assert_eq!(metrics.as_deref(), Some(json));
    }

    /// Simulates the wire output of the **old** `FlightMessage::AffectedRows(usize)`
    /// variant and verifies that the **new** `FlightDecoder` handles it.
    #[test]
    fn test_old_affected_rows_format_decoded_by_new_code() {
        // The old encoder produced FlightData whose app_metadata is
        // FlightMetadata { affected_rows, metrics: None }. The new
        // `AffectedRows { rows, metrics: Option<String> }` variant with
        // `metrics: None` produces the exact same wire bytes.
        let old_wire_bytes = FlightData {
            flight_descriptor: None,
            data_header: build_none_flight_msg().into(),
            app_metadata: FlightMetadata {
                affected_rows: Some(AffectedRows { value: 99 }),
                metrics: None, // old format: no metrics field
            }
            .encode_to_vec()
            .into(),
            data_body: ProstBytes::default(),
        };

        let mut decoder = FlightDecoder::default();
        let decoded = decoder.try_decode(&old_wire_bytes).unwrap().unwrap();
        assert!(matches!(
            decoded,
            FlightMessage::AffectedRows {
                rows: 99,
                metrics: None,
            }
        ));
    }
}