common_grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::{RecordBatch, RecordBatches};
25use datatypes::arrow;
26use datatypes::arrow::datatypes::Schema as ArrowSchema;
27use datatypes::arrow::ipc::{root_as_message, writer, MessageHeader};
28use datatypes::schema::{Schema, SchemaRef};
29use flatbuffers::FlatBufferBuilder;
30use prost::bytes::Bytes as ProstBytes;
31use prost::Message;
32use snafu::{OptionExt, ResultExt};
33
34use crate::error::{
35    ConvertArrowSchemaSnafu, CreateRecordBatchSnafu, DecodeFlightDataSnafu, InvalidFlightDataSnafu,
36    Result,
37};
38
39#[derive(Debug, Clone)]
40pub enum FlightMessage {
41    Schema(SchemaRef),
42    Recordbatch(RecordBatch),
43    AffectedRows(usize),
44    Metrics(String),
45}
46
47pub struct FlightEncoder {
48    write_options: writer::IpcWriteOptions,
49    data_gen: writer::IpcDataGenerator,
50    dictionary_tracker: writer::DictionaryTracker,
51}
52
53impl Default for FlightEncoder {
54    fn default() -> Self {
55        let write_options = writer::IpcWriteOptions::default()
56            .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
57            .unwrap();
58
59        Self {
60            write_options,
61            data_gen: writer::IpcDataGenerator::default(),
62            dictionary_tracker: writer::DictionaryTracker::new(false),
63        }
64    }
65}
66
67impl FlightEncoder {
68    pub fn encode(&mut self, flight_message: FlightMessage) -> FlightData {
69        match flight_message {
70            FlightMessage::Schema(schema) => {
71                SchemaAsIpc::new(schema.arrow_schema(), &self.write_options).into()
72            }
73            FlightMessage::Recordbatch(recordbatch) => {
74                let (encoded_dictionaries, encoded_batch) = self
75                    .data_gen
76                    .encoded_batch(
77                        recordbatch.df_record_batch(),
78                        &mut self.dictionary_tracker,
79                        &self.write_options,
80                    )
81                    .expect("DictionaryTracker configured above to not fail on replacement");
82
83                // TODO(LFC): Handle dictionary as FlightData here, when we supported Arrow's Dictionary DataType.
84                // Currently we don't have a datatype corresponding to Arrow's Dictionary DataType,
85                // so there won't be any "dictionaries" here. Assert to be sure about it, and
86                // perform a "testing guard" in case we forgot to handle the possible "dictionaries"
87                // here in the future.
88                debug_assert_eq!(encoded_dictionaries.len(), 0);
89
90                encoded_batch.into()
91            }
92            FlightMessage::AffectedRows(rows) => {
93                let metadata = FlightMetadata {
94                    affected_rows: Some(AffectedRows { value: rows as _ }),
95                    metrics: None,
96                }
97                .encode_to_vec();
98                FlightData {
99                    flight_descriptor: None,
100                    data_header: build_none_flight_msg().into(),
101                    app_metadata: metadata.into(),
102                    data_body: ProstBytes::default(),
103                }
104            }
105            FlightMessage::Metrics(s) => {
106                let metadata = FlightMetadata {
107                    affected_rows: None,
108                    metrics: Some(Metrics {
109                        metrics: s.as_bytes().to_vec(),
110                    }),
111                }
112                .encode_to_vec();
113                FlightData {
114                    flight_descriptor: None,
115                    data_header: build_none_flight_msg().into(),
116                    app_metadata: metadata.into(),
117                    data_body: ProstBytes::default(),
118                }
119            }
120        }
121    }
122}
123
124#[derive(Default)]
125pub struct FlightDecoder {
126    schema: Option<SchemaRef>,
127}
128
129impl FlightDecoder {
130    pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<FlightMessage> {
131        let message = root_as_message(&flight_data.data_header).map_err(|e| {
132            InvalidFlightDataSnafu {
133                reason: e.to_string(),
134            }
135            .build()
136        })?;
137        match message.header_type() {
138            MessageHeader::NONE => {
139                let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
140                    .context(DecodeFlightDataSnafu)?;
141                if let Some(AffectedRows { value }) = metadata.affected_rows {
142                    return Ok(FlightMessage::AffectedRows(value as _));
143                }
144                if let Some(Metrics { metrics }) = metadata.metrics {
145                    return Ok(FlightMessage::Metrics(
146                        String::from_utf8_lossy(&metrics).to_string(),
147                    ));
148                }
149                InvalidFlightDataSnafu {
150                    reason: "Expecting FlightMetadata have some meaningful content.",
151                }
152                .fail()
153            }
154            MessageHeader::Schema => {
155                let arrow_schema = ArrowSchema::try_from(flight_data).map_err(|e| {
156                    InvalidFlightDataSnafu {
157                        reason: e.to_string(),
158                    }
159                    .build()
160                })?;
161                let schema =
162                    Arc::new(Schema::try_from(arrow_schema).context(ConvertArrowSchemaSnafu)?);
163
164                self.schema = Some(schema.clone());
165
166                Ok(FlightMessage::Schema(schema))
167            }
168            MessageHeader::RecordBatch => {
169                let schema = self.schema.clone().context(InvalidFlightDataSnafu {
170                    reason: "Should have decoded schema first!",
171                })?;
172                let arrow_schema = schema.arrow_schema().clone();
173
174                let arrow_batch =
175                    flight_data_to_arrow_batch(flight_data, arrow_schema, &HashMap::new())
176                        .map_err(|e| {
177                            InvalidFlightDataSnafu {
178                                reason: e.to_string(),
179                            }
180                            .build()
181                        })?;
182                let recordbatch = RecordBatch::try_from_df_record_batch(schema, arrow_batch)
183                    .context(CreateRecordBatchSnafu)?;
184                Ok(FlightMessage::Recordbatch(recordbatch))
185            }
186            other => {
187                let name = other.variant_name().unwrap_or("UNKNOWN");
188                InvalidFlightDataSnafu {
189                    reason: format!("Unsupported FlightData type: {name}"),
190                }
191                .fail()
192            }
193        }
194    }
195
196    pub fn schema(&self) -> Option<&SchemaRef> {
197        self.schema.as_ref()
198    }
199}
200
201pub fn flight_messages_to_recordbatches(messages: Vec<FlightMessage>) -> Result<RecordBatches> {
202    if messages.is_empty() {
203        Ok(RecordBatches::empty())
204    } else {
205        let mut recordbatches = Vec::with_capacity(messages.len() - 1);
206
207        let schema = match &messages[0] {
208            FlightMessage::Schema(schema) => schema.clone(),
209            _ => {
210                return InvalidFlightDataSnafu {
211                    reason: "First Flight Message must be schema!",
212                }
213                .fail()
214            }
215        };
216
217        for message in messages.into_iter().skip(1) {
218            match message {
219                FlightMessage::Recordbatch(recordbatch) => recordbatches.push(recordbatch),
220                _ => {
221                    return InvalidFlightDataSnafu {
222                        reason: "Expect the following Flight Messages are all Recordbatches!",
223                    }
224                    .fail()
225                }
226            }
227        }
228
229        RecordBatches::try_new(schema, recordbatches).context(CreateRecordBatchSnafu)
230    }
231}
232
233fn build_none_flight_msg() -> Bytes {
234    let mut builder = FlatBufferBuilder::new();
235
236    let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
237    message.add_version(arrow::ipc::MetadataVersion::V5);
238    message.add_header_type(MessageHeader::NONE);
239    message.add_bodyLength(0);
240
241    let data = message.finish();
242    builder.finish(data, None);
243
244    builder.finished_data().into()
245}
246
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::datatypes::{DataType, Field};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::Int32Vector;

    use super::*;
    use crate::Error;

    /// Builds a single-column record batch of nullable `i32` values under `schema`.
    fn int32_batch(schema: &SchemaRef, values: Vec<Option<i32>>) -> RecordBatch {
        RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from(values)) as _],
        )
        .unwrap()
    }

    /// Asserts that `result` is an `InvalidFlightData` error whose text contains `expected`.
    fn assert_invalid_flight_data<T: std::fmt::Debug>(
        result: std::result::Result<T, Error>,
        expected: &str,
    ) {
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        let rendered = result.unwrap_err().to_string();
        assert!(rendered.contains(expected));
    }

    #[test]
    fn test_try_decode() {
        let arrow_schema = ArrowSchema::new(vec![Field::new("n", DataType::Int32, true)]);
        let schema: SchemaRef = Arc::new(Schema::try_from(arrow_schema.clone()).unwrap());

        let batch1 = int32_batch(&schema, vec![Some(1), None, Some(3)]);
        let batch2 = int32_batch(&schema, vec![None, Some(5)]);

        let df_batches = vec![
            batch1.clone().into_df_record_batch(),
            batch2.clone().into_df_record_batch(),
        ];
        let flight_data = batches_to_flight_data(&arrow_schema, df_batches).unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let mut decoder = FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // A record batch cannot be decoded before the schema has been seen.
        assert_invalid_flight_data(decoder.try_decode(d2), "Should have decoded schema first!");

        match decoder.try_decode(d1).unwrap() {
            FlightMessage::Schema(decoded_schema) => assert_eq!(decoded_schema, schema),
            other => panic!("expected a schema message, got {other:?}"),
        }

        assert!(decoder.schema.is_some());

        // With the schema known, both record batch frames decode to the original batches.
        for (frame, expected) in [(d2, &batch1), (d3, &batch2)] {
            match decoder.try_decode(frame).unwrap() {
                FlightMessage::Recordbatch(actual_batch) => assert_eq!(&actual_batch, expected),
                other => panic!("expected a record batch message, got {other:?}"),
            }
        }
    }

    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema: SchemaRef = Arc::new(Schema::new(vec![ColumnSchema::new(
            "m",
            ConcreteDataType::int32_datatype(),
            true,
        )]));
        let batch1 = int32_batch(&schema, vec![Some(2), None, Some(4)]);
        let batch2 = int32_batch(&schema, vec![None, Some(6)]);
        let recordbatches =
            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::Recordbatch(batch1);
        let m3 = FlightMessage::Recordbatch(batch2);

        // The schema must come first.
        assert_invalid_flight_data(
            flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]),
            "First Flight Message must be schema!",
        );

        // Only record batches may follow the schema.
        assert_invalid_flight_data(
            flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]),
            "Expect the following Flight Messages are all Recordbatches!",
        );

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }
}