// common_grpc/flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::DfRecordBatch;
25use datatypes::arrow;
26use datatypes::arrow::buffer::Buffer;
27use datatypes::arrow::datatypes::{Schema as ArrowSchema, SchemaRef};
28use datatypes::arrow::error::ArrowError;
29use datatypes::arrow::ipc::{convert, reader, root_as_message, writer, MessageHeader};
30use flatbuffers::FlatBufferBuilder;
31use prost::bytes::Bytes as ProstBytes;
32use prost::Message;
33use snafu::{OptionExt, ResultExt};
34
35use crate::error;
36use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
37
/// A message exchanged over Arrow Flight, as produced by [FlightDecoder] and
/// consumed by [FlightEncoder].
#[derive(Debug, Clone)]
pub enum FlightMessage {
    /// The Arrow schema of the stream; it has to be decoded before any
    /// `RecordBatch` message can be decoded.
    Schema(SchemaRef),
    /// A batch of rows conforming to the previously sent schema.
    RecordBatch(DfRecordBatch),
    /// Number of affected rows, transported inside the protobuf `FlightMetadata`
    /// carried in `app_metadata`.
    AffectedRows(usize),
    /// Metrics text, transported as UTF-8 bytes inside the protobuf
    /// `FlightMetadata` carried in `app_metadata` (decoded lossily).
    Metrics(String),
}
45
/// Encodes [FlightMessage]s into Arrow Flight [FlightData] frames.
///
/// The Arrow IPC writer state below is kept across calls to `encode`.
pub struct FlightEncoder {
    /// IPC write options, including the body compression setting.
    write_options: writer::IpcWriteOptions,
    /// Produces the IPC-encoded representation of record batches.
    data_gen: writer::IpcDataGenerator,
    /// Dictionary state required by `IpcDataGenerator::encoded_batch`; it is
    /// constructed with `new(false)` so that dictionary replacement does not error.
    dictionary_tracker: writer::DictionaryTracker,
}
51
52impl Default for FlightEncoder {
53    fn default() -> Self {
54        let write_options = writer::IpcWriteOptions::default()
55            .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
56            .unwrap();
57
58        Self {
59            write_options,
60            data_gen: writer::IpcDataGenerator::default(),
61            dictionary_tracker: writer::DictionaryTracker::new(false),
62        }
63    }
64}
65
66impl FlightEncoder {
67    /// Creates new [FlightEncoder] with compression disabled.
68    pub fn with_compression_disabled() -> Self {
69        let write_options = writer::IpcWriteOptions::default()
70            .try_with_compression(None)
71            .unwrap();
72
73        Self {
74            write_options,
75            data_gen: writer::IpcDataGenerator::default(),
76            dictionary_tracker: writer::DictionaryTracker::new(false),
77        }
78    }
79
80    pub fn encode(&mut self, flight_message: FlightMessage) -> FlightData {
81        match flight_message {
82            FlightMessage::Schema(schema) => SchemaAsIpc::new(&schema, &self.write_options).into(),
83            FlightMessage::RecordBatch(record_batch) => {
84                let (encoded_dictionaries, encoded_batch) = self
85                    .data_gen
86                    .encoded_batch(
87                        &record_batch,
88                        &mut self.dictionary_tracker,
89                        &self.write_options,
90                    )
91                    .expect("DictionaryTracker configured above to not fail on replacement");
92
93                // TODO(LFC): Handle dictionary as FlightData here, when we supported Arrow's Dictionary DataType.
94                // Currently we don't have a datatype corresponding to Arrow's Dictionary DataType,
95                // so there won't be any "dictionaries" here. Assert to be sure about it, and
96                // perform a "testing guard" in case we forgot to handle the possible "dictionaries"
97                // here in the future.
98                debug_assert_eq!(encoded_dictionaries.len(), 0);
99
100                encoded_batch.into()
101            }
102            FlightMessage::AffectedRows(rows) => {
103                let metadata = FlightMetadata {
104                    affected_rows: Some(AffectedRows { value: rows as _ }),
105                    metrics: None,
106                }
107                .encode_to_vec();
108                FlightData {
109                    flight_descriptor: None,
110                    data_header: build_none_flight_msg().into(),
111                    app_metadata: metadata.into(),
112                    data_body: ProstBytes::default(),
113                }
114            }
115            FlightMessage::Metrics(s) => {
116                let metadata = FlightMetadata {
117                    affected_rows: None,
118                    metrics: Some(Metrics {
119                        metrics: s.as_bytes().to_vec(),
120                    }),
121                }
122                .encode_to_vec();
123                FlightData {
124                    flight_descriptor: None,
125                    data_header: build_none_flight_msg().into(),
126                    app_metadata: metadata.into(),
127                    data_body: ProstBytes::default(),
128                }
129            }
130        }
131    }
132}
133
/// Decodes Arrow Flight [FlightData] frames back into [FlightMessage]s.
///
/// The decoder is stateful: a schema has to be decoded first (or supplied via
/// `try_from_schema_bytes`) before record batches can be decoded.
#[derive(Default)]
pub struct FlightDecoder {
    /// Schema used to decode subsequent record batches.
    schema: Option<SchemaRef>,
    /// The raw IPC message bytes (`data_header`) the schema was decoded from.
    schema_bytes: Option<bytes::Bytes>,
}
139
140impl FlightDecoder {
141    /// Build a [FlightDecoder] instance from provided schema bytes.
142    pub fn try_from_schema_bytes(schema_bytes: &bytes::Bytes) -> Result<Self> {
143        let arrow_schema = convert::try_schema_from_flatbuffer_bytes(&schema_bytes[..])
144            .context(error::ArrowSnafu)?;
145        Ok(Self {
146            schema: Some(Arc::new(arrow_schema)),
147            schema_bytes: Some(schema_bytes.clone()),
148        })
149    }
150
151    pub fn try_decode_record_batch(
152        &mut self,
153        data_header: &bytes::Bytes,
154        data_body: &bytes::Bytes,
155    ) -> Result<DfRecordBatch> {
156        let schema = self
157            .schema
158            .as_ref()
159            .context(InvalidFlightDataSnafu {
160                reason: "Should have decoded schema first!",
161            })?
162            .clone();
163        let message = root_as_message(&data_header[..])
164            .map_err(|err| {
165                ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
166            })
167            .context(error::ArrowSnafu)?;
168        let result = message
169            .header_as_record_batch()
170            .ok_or_else(|| {
171                ArrowError::ParseError(
172                    "Unable to convert flight data header to a record batch".to_string(),
173                )
174            })
175            .and_then(|batch| {
176                reader::read_record_batch(
177                    &Buffer::from(data_body.as_ref()),
178                    batch,
179                    schema,
180                    &HashMap::new(),
181                    None,
182                    &message.version(),
183                )
184            })
185            .context(error::ArrowSnafu)?;
186        Ok(result)
187    }
188
189    pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<FlightMessage> {
190        let message = root_as_message(&flight_data.data_header).map_err(|e| {
191            InvalidFlightDataSnafu {
192                reason: e.to_string(),
193            }
194            .build()
195        })?;
196        match message.header_type() {
197            MessageHeader::NONE => {
198                let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
199                    .context(DecodeFlightDataSnafu)?;
200                if let Some(AffectedRows { value }) = metadata.affected_rows {
201                    return Ok(FlightMessage::AffectedRows(value as _));
202                }
203                if let Some(Metrics { metrics }) = metadata.metrics {
204                    return Ok(FlightMessage::Metrics(
205                        String::from_utf8_lossy(&metrics).to_string(),
206                    ));
207                }
208                InvalidFlightDataSnafu {
209                    reason: "Expecting FlightMetadata have some meaningful content.",
210                }
211                .fail()
212            }
213            MessageHeader::Schema => {
214                let arrow_schema = Arc::new(ArrowSchema::try_from(flight_data).map_err(|e| {
215                    InvalidFlightDataSnafu {
216                        reason: e.to_string(),
217                    }
218                    .build()
219                })?);
220                self.schema = Some(arrow_schema.clone());
221                self.schema_bytes = Some(flight_data.data_header.clone());
222                Ok(FlightMessage::Schema(arrow_schema))
223            }
224            MessageHeader::RecordBatch => {
225                let schema = self.schema.clone().context(InvalidFlightDataSnafu {
226                    reason: "Should have decoded schema first!",
227                })?;
228                let arrow_batch =
229                    flight_data_to_arrow_batch(flight_data, schema.clone(), &HashMap::new())
230                        .map_err(|e| {
231                            InvalidFlightDataSnafu {
232                                reason: e.to_string(),
233                            }
234                            .build()
235                        })?;
236                Ok(FlightMessage::RecordBatch(arrow_batch))
237            }
238            other => {
239                let name = other.variant_name().unwrap_or("UNKNOWN");
240                InvalidFlightDataSnafu {
241                    reason: format!("Unsupported FlightData type: {name}"),
242                }
243                .fail()
244            }
245        }
246    }
247
248    pub fn schema(&self) -> Option<&SchemaRef> {
249        self.schema.as_ref()
250    }
251
252    pub fn schema_bytes(&self) -> Option<bytes::Bytes> {
253        self.schema_bytes.clone()
254    }
255}
256
257pub fn flight_messages_to_recordbatches(
258    messages: Vec<FlightMessage>,
259) -> Result<Vec<DfRecordBatch>> {
260    if messages.is_empty() {
261        Ok(vec![])
262    } else {
263        let mut recordbatches = Vec::with_capacity(messages.len() - 1);
264
265        match &messages[0] {
266            FlightMessage::Schema(_schema) => {}
267            _ => {
268                return InvalidFlightDataSnafu {
269                    reason: "First Flight Message must be schema!",
270                }
271                .fail()
272            }
273        };
274
275        for message in messages.into_iter().skip(1) {
276            match message {
277                FlightMessage::RecordBatch(recordbatch) => recordbatches.push(recordbatch),
278                _ => {
279                    return InvalidFlightDataSnafu {
280                        reason: "Expect the following Flight Messages are all Recordbatches!",
281                    }
282                    .fail()
283                }
284            }
285        }
286
287        Ok(recordbatches)
288    }
289}
290
291fn build_none_flight_msg() -> Bytes {
292    let mut builder = FlatBufferBuilder::new();
293
294    let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
295    message.add_version(arrow::ipc::MetadataVersion::V5);
296    message.add_header_type(MessageHeader::NONE);
297    message.add_bodyLength(0);
298
299    let data = message.finish();
300    builder.finish(data, None);
301
302    builder.finished_data().into()
303}
304
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::Error;

    /// Builds a single-column nullable `Int32` record batch over `schema`.
    fn int32_batch(schema: &SchemaRef, values: Vec<Option<i32>>) -> DfRecordBatch {
        DfRecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(values)) as _])
            .unwrap()
    }

    #[test]
    fn test_try_decode() {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));
        let batch1 = int32_batch(&schema, vec![Some(1), None, Some(3)]);
        let batch2 = int32_batch(&schema, vec![None, Some(5)]);

        let flight_data =
            batches_to_flight_data(&schema, vec![batch1.clone(), batch2.clone()]).unwrap();
        let [d1, d2, d3] = flight_data.as_slice() else {
            panic!(
                "expected 1 schema frame and 2 record batch frames, got {} frames",
                flight_data.len()
            );
        };

        let decoder = &mut FlightDecoder::default();
        assert!(decoder.schema().is_none());
        assert!(decoder.schema_bytes().is_none());

        // Record batches cannot be decoded before the schema is known.
        let result = decoder.try_decode(d2);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Should have decoded schema first!"));

        let FlightMessage::Schema(decoded_schema) = decoder.try_decode(d1).unwrap() else {
            panic!("expected a Schema message");
        };
        assert_eq!(decoded_schema, schema);
        assert_eq!(decoder.schema(), Some(&schema));
        assert!(decoder.schema_bytes().is_some());

        for (data, expected) in [(d2, &batch1), (d3, &batch2)] {
            let FlightMessage::RecordBatch(actual) = decoder.try_decode(data).unwrap() else {
                panic!("expected a RecordBatch message");
            };
            assert_eq!(&actual, expected);
        }
    }

    #[test]
    fn test_encode_decode_metadata_messages() {
        let mut encoder = FlightEncoder::default();
        let mut decoder = FlightDecoder::default();

        let encoded = encoder.encode(FlightMessage::AffectedRows(42));
        let FlightMessage::AffectedRows(rows) = decoder.try_decode(&encoded).unwrap() else {
            panic!("expected an AffectedRows message");
        };
        assert_eq!(rows, 42);

        let metrics = "{\"elapsed\":1}".to_string();
        let encoded = encoder.encode(FlightMessage::Metrics(metrics.clone()));
        let FlightMessage::Metrics(decoded) = decoder.try_decode(&encoded).unwrap() else {
            panic!("expected a Metrics message");
        };
        assert_eq!(decoded, metrics);
    }

    #[test]
    fn test_encode_decode_schema_and_record_batch() {
        let schema = Arc::new(Schema::new(vec![Field::new("k", DataType::Int32, true)]));
        let batch = int32_batch(&schema, vec![Some(7), None, Some(9)]);

        let mut encoder = FlightEncoder::with_compression_disabled();
        let mut decoder = FlightDecoder::default();

        let encoded = encoder.encode(FlightMessage::Schema(schema.clone()));
        let FlightMessage::Schema(decoded_schema) = decoder.try_decode(&encoded).unwrap() else {
            panic!("expected a Schema message");
        };
        assert_eq!(decoded_schema, schema);

        let encoded_batch = encoder.encode(FlightMessage::RecordBatch(batch.clone()));
        let FlightMessage::RecordBatch(decoded_batch) =
            decoder.try_decode(&encoded_batch).unwrap()
        else {
            panic!("expected a RecordBatch message");
        };
        assert_eq!(decoded_batch, batch);

        // A decoder restored from the raw schema bytes can decode the same batch.
        let schema_bytes = decoder.schema_bytes().unwrap();
        let mut restored = FlightDecoder::try_from_schema_bytes(&schema_bytes).unwrap();
        assert_eq!(restored.schema(), Some(&schema));
        let decoded_batch = restored
            .try_decode_record_batch(&encoded_batch.data_header, &encoded_batch.data_body)
            .unwrap();
        assert_eq!(decoded_batch, batch);
    }

    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
        let batch1 = int32_batch(&schema, vec![Some(2), None, Some(4)]);
        let batch2 = int32_batch(&schema, vec![None, Some(6)]);
        let recordbatches = vec![batch1.clone(), batch2.clone()];

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::RecordBatch(batch1);
        let m3 = FlightMessage::RecordBatch(batch2);

        assert!(flight_messages_to_recordbatches(vec![]).unwrap().is_empty());

        let result = flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("First Flight Message must be schema!"));

        let result = flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Expect the following Flight Messages are all Recordbatches!"));

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }
}