common_grpc/flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::DfRecordBatch;
25use datatypes::arrow;
26use datatypes::arrow::buffer::Buffer;
27use datatypes::arrow::datatypes::{Schema as ArrowSchema, SchemaRef};
28use datatypes::arrow::error::ArrowError;
29use datatypes::arrow::ipc::{convert, reader, root_as_message, writer, MessageHeader};
30use flatbuffers::FlatBufferBuilder;
31use prost::bytes::Bytes as ProstBytes;
32use prost::Message;
33use snafu::{OptionExt, ResultExt};
34
35use crate::error;
36use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
37
38#[derive(Debug, Clone)]
39pub enum FlightMessage {
40    Schema(SchemaRef),
41    RecordBatch(DfRecordBatch),
42    AffectedRows(usize),
43    Metrics(String),
44}
45
46pub struct FlightEncoder {
47    write_options: writer::IpcWriteOptions,
48    data_gen: writer::IpcDataGenerator,
49    dictionary_tracker: writer::DictionaryTracker,
50}
51
52impl Default for FlightEncoder {
53    fn default() -> Self {
54        let write_options = writer::IpcWriteOptions::default()
55            .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
56            .unwrap();
57
58        Self {
59            write_options,
60            data_gen: writer::IpcDataGenerator::default(),
61            dictionary_tracker: writer::DictionaryTracker::new(false),
62        }
63    }
64}
65
66impl FlightEncoder {
67    pub fn encode(&mut self, flight_message: FlightMessage) -> FlightData {
68        match flight_message {
69            FlightMessage::Schema(schema) => SchemaAsIpc::new(&schema, &self.write_options).into(),
70            FlightMessage::RecordBatch(record_batch) => {
71                let (encoded_dictionaries, encoded_batch) = self
72                    .data_gen
73                    .encoded_batch(
74                        &record_batch,
75                        &mut self.dictionary_tracker,
76                        &self.write_options,
77                    )
78                    .expect("DictionaryTracker configured above to not fail on replacement");
79
80                // TODO(LFC): Handle dictionary as FlightData here, when we supported Arrow's Dictionary DataType.
81                // Currently we don't have a datatype corresponding to Arrow's Dictionary DataType,
82                // so there won't be any "dictionaries" here. Assert to be sure about it, and
83                // perform a "testing guard" in case we forgot to handle the possible "dictionaries"
84                // here in the future.
85                debug_assert_eq!(encoded_dictionaries.len(), 0);
86
87                encoded_batch.into()
88            }
89            FlightMessage::AffectedRows(rows) => {
90                let metadata = FlightMetadata {
91                    affected_rows: Some(AffectedRows { value: rows as _ }),
92                    metrics: None,
93                }
94                .encode_to_vec();
95                FlightData {
96                    flight_descriptor: None,
97                    data_header: build_none_flight_msg().into(),
98                    app_metadata: metadata.into(),
99                    data_body: ProstBytes::default(),
100                }
101            }
102            FlightMessage::Metrics(s) => {
103                let metadata = FlightMetadata {
104                    affected_rows: None,
105                    metrics: Some(Metrics {
106                        metrics: s.as_bytes().to_vec(),
107                    }),
108                }
109                .encode_to_vec();
110                FlightData {
111                    flight_descriptor: None,
112                    data_header: build_none_flight_msg().into(),
113                    app_metadata: metadata.into(),
114                    data_body: ProstBytes::default(),
115                }
116            }
117        }
118    }
119}
120
121#[derive(Default)]
122pub struct FlightDecoder {
123    schema: Option<SchemaRef>,
124    schema_bytes: Option<bytes::Bytes>,
125}
126
127impl FlightDecoder {
128    /// Build a [FlightDecoder] instance from provided schema bytes.
129    pub fn try_from_schema_bytes(schema_bytes: &bytes::Bytes) -> Result<Self> {
130        let arrow_schema = convert::try_schema_from_flatbuffer_bytes(&schema_bytes[..])
131            .context(error::ArrowSnafu)?;
132        Ok(Self {
133            schema: Some(Arc::new(arrow_schema)),
134            schema_bytes: Some(schema_bytes.clone()),
135        })
136    }
137
138    pub fn try_decode_record_batch(
139        &mut self,
140        data_header: &bytes::Bytes,
141        data_body: &bytes::Bytes,
142    ) -> Result<DfRecordBatch> {
143        let schema = self
144            .schema
145            .as_ref()
146            .context(InvalidFlightDataSnafu {
147                reason: "Should have decoded schema first!",
148            })?
149            .clone();
150        let message = root_as_message(&data_header[..])
151            .map_err(|err| {
152                ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
153            })
154            .context(error::ArrowSnafu)?;
155        let result = message
156            .header_as_record_batch()
157            .ok_or_else(|| {
158                ArrowError::ParseError(
159                    "Unable to convert flight data header to a record batch".to_string(),
160                )
161            })
162            .and_then(|batch| {
163                reader::read_record_batch(
164                    &Buffer::from(data_body.as_ref()),
165                    batch,
166                    schema,
167                    &HashMap::new(),
168                    None,
169                    &message.version(),
170                )
171            })
172            .context(error::ArrowSnafu)?;
173        Ok(result)
174    }
175
176    pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<FlightMessage> {
177        let message = root_as_message(&flight_data.data_header).map_err(|e| {
178            InvalidFlightDataSnafu {
179                reason: e.to_string(),
180            }
181            .build()
182        })?;
183        match message.header_type() {
184            MessageHeader::NONE => {
185                let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
186                    .context(DecodeFlightDataSnafu)?;
187                if let Some(AffectedRows { value }) = metadata.affected_rows {
188                    return Ok(FlightMessage::AffectedRows(value as _));
189                }
190                if let Some(Metrics { metrics }) = metadata.metrics {
191                    return Ok(FlightMessage::Metrics(
192                        String::from_utf8_lossy(&metrics).to_string(),
193                    ));
194                }
195                InvalidFlightDataSnafu {
196                    reason: "Expecting FlightMetadata have some meaningful content.",
197                }
198                .fail()
199            }
200            MessageHeader::Schema => {
201                let arrow_schema = Arc::new(ArrowSchema::try_from(flight_data).map_err(|e| {
202                    InvalidFlightDataSnafu {
203                        reason: e.to_string(),
204                    }
205                    .build()
206                })?);
207                self.schema = Some(arrow_schema.clone());
208                self.schema_bytes = Some(flight_data.data_header.clone());
209                Ok(FlightMessage::Schema(arrow_schema))
210            }
211            MessageHeader::RecordBatch => {
212                let schema = self.schema.clone().context(InvalidFlightDataSnafu {
213                    reason: "Should have decoded schema first!",
214                })?;
215                let arrow_batch =
216                    flight_data_to_arrow_batch(flight_data, schema.clone(), &HashMap::new())
217                        .map_err(|e| {
218                            InvalidFlightDataSnafu {
219                                reason: e.to_string(),
220                            }
221                            .build()
222                        })?;
223                Ok(FlightMessage::RecordBatch(arrow_batch))
224            }
225            other => {
226                let name = other.variant_name().unwrap_or("UNKNOWN");
227                InvalidFlightDataSnafu {
228                    reason: format!("Unsupported FlightData type: {name}"),
229                }
230                .fail()
231            }
232        }
233    }
234
235    pub fn schema(&self) -> Option<&SchemaRef> {
236        self.schema.as_ref()
237    }
238
239    pub fn schema_bytes(&self) -> Option<bytes::Bytes> {
240        self.schema_bytes.clone()
241    }
242}
243
244pub fn flight_messages_to_recordbatches(
245    messages: Vec<FlightMessage>,
246) -> Result<Vec<DfRecordBatch>> {
247    if messages.is_empty() {
248        Ok(vec![])
249    } else {
250        let mut recordbatches = Vec::with_capacity(messages.len() - 1);
251
252        match &messages[0] {
253            FlightMessage::Schema(_schema) => {}
254            _ => {
255                return InvalidFlightDataSnafu {
256                    reason: "First Flight Message must be schema!",
257                }
258                .fail()
259            }
260        };
261
262        for message in messages.into_iter().skip(1) {
263            match message {
264                FlightMessage::RecordBatch(recordbatch) => recordbatches.push(recordbatch),
265                _ => {
266                    return InvalidFlightDataSnafu {
267                        reason: "Expect the following Flight Messages are all Recordbatches!",
268                    }
269                    .fail()
270                }
271            }
272        }
273
274        Ok(recordbatches)
275    }
276}
277
278fn build_none_flight_msg() -> Bytes {
279    let mut builder = FlatBufferBuilder::new();
280
281    let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
282    message.add_version(arrow::ipc::MetadataVersion::V5);
283    message.add_header_type(MessageHeader::NONE);
284    message.add_bodyLength(0);
285
286    let data = message.finish();
287    builder.finish(data, None);
288
289    builder.finished_data().into()
290}
291
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::Error;

    #[test]
    fn test_try_decode() {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));

        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(5)])) as _],
        )
        .unwrap();

        let flight_data =
            batches_to_flight_data(&schema, vec![batch1.clone(), batch2.clone()]).unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let decoder = &mut FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // A record batch before any schema must be rejected.
        let result = decoder.try_decode(d2);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Should have decoded schema first!"));

        let message = decoder.try_decode(d1).unwrap();
        assert!(matches!(message, FlightMessage::Schema(_)));
        let FlightMessage::Schema(decoded_schema) = message else {
            unreachable!()
        };
        assert_eq!(decoded_schema, schema);

        // Decoding the schema message must have stored it in the decoder.
        assert!(decoder.schema.is_some());

        let message = decoder.try_decode(d2).unwrap();
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(actual_batch) = message else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch1);

        let message = decoder.try_decode(d3).unwrap();
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(actual_batch) = message else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch2);
    }

    #[test]
    fn test_encode_decode_metadata_messages() {
        let mut encoder = FlightEncoder::default();
        let mut decoder = FlightDecoder::default();

        // Metadata-only messages decode without any schema having been seen.
        let flight_data = encoder.encode(FlightMessage::AffectedRows(42));
        let message = decoder.try_decode(&flight_data).unwrap();
        assert!(matches!(message, FlightMessage::AffectedRows(42)));

        let metrics = r#"{"elapsed_ms":7}"#.to_string();
        let flight_data = encoder.encode(FlightMessage::Metrics(metrics.clone()));
        let FlightMessage::Metrics(decoded) = decoder.try_decode(&flight_data).unwrap() else {
            unreachable!()
        };
        assert_eq!(decoded, metrics);

        // Metadata frames must not touch the decoder's schema state.
        assert!(decoder.schema().is_none());
    }

    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(6)])) as _],
        )
        .unwrap();
        let recordbatches = vec![batch1.clone(), batch2.clone()];

        // An empty message list is valid and yields no batches.
        assert!(flight_messages_to_recordbatches(vec![]).unwrap().is_empty());

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::RecordBatch(batch1);
        let m3 = FlightMessage::RecordBatch(batch2);

        let result = flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("First Flight Message must be schema!"));

        let result = flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Expect the following Flight Messages are all Recordbatches!"));

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }
}