1pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::DfRecordBatch;
25use datatypes::arrow;
26use datatypes::arrow::array::ArrayRef;
27use datatypes::arrow::buffer::Buffer;
28use datatypes::arrow::datatypes::{DataType, Schema as ArrowSchema, SchemaRef};
29use datatypes::arrow::error::ArrowError;
30use datatypes::arrow::ipc::{convert, reader, root_as_message, writer, MessageHeader};
31use flatbuffers::FlatBufferBuilder;
32use prost::bytes::Bytes as ProstBytes;
33use prost::Message;
34use snafu::{OptionExt, ResultExt};
35use vec1::{vec1, Vec1};
36
37use crate::error;
38use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
39
40#[derive(Debug, Clone)]
41pub enum FlightMessage {
42 Schema(SchemaRef),
43 RecordBatch(DfRecordBatch),
44 AffectedRows(usize),
45 Metrics(String),
46}
47
48pub struct FlightEncoder {
49 write_options: writer::IpcWriteOptions,
50 data_gen: writer::IpcDataGenerator,
51 dictionary_tracker: writer::DictionaryTracker,
52}
53
54impl Default for FlightEncoder {
55 fn default() -> Self {
56 let write_options = writer::IpcWriteOptions::default()
57 .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
58 .unwrap();
59
60 Self {
61 write_options,
62 data_gen: writer::IpcDataGenerator::default(),
63 dictionary_tracker: writer::DictionaryTracker::new(false),
64 }
65 }
66}
67
68impl FlightEncoder {
69 pub fn with_compression_disabled() -> Self {
71 let write_options = writer::IpcWriteOptions::default()
72 .try_with_compression(None)
73 .unwrap();
74
75 Self {
76 write_options,
77 data_gen: writer::IpcDataGenerator::default(),
78 dictionary_tracker: writer::DictionaryTracker::new(false),
79 }
80 }
81
82 pub fn encode_schema(&self, schema: &ArrowSchema) -> FlightData {
84 SchemaAsIpc::new(schema, &self.write_options).into()
85 }
86
87 pub fn encode(&mut self, flight_message: FlightMessage) -> Vec1<FlightData> {
93 match flight_message {
94 FlightMessage::Schema(schema) => {
95 schema.fields().iter().for_each(|x| {
96 if matches!(x.data_type(), DataType::Dictionary(_, _)) {
97 self.dictionary_tracker.next_dict_id();
98 }
99 });
100
101 vec1![self.encode_schema(schema.as_ref())]
102 }
103 FlightMessage::RecordBatch(record_batch) => {
104 let (encoded_dictionaries, encoded_batch) = self
105 .data_gen
106 .encoded_batch(
107 &record_batch,
108 &mut self.dictionary_tracker,
109 &self.write_options,
110 )
111 .expect("DictionaryTracker configured above to not fail on replacement");
112
113 Vec1::from_vec_push(
114 encoded_dictionaries.into_iter().map(Into::into).collect(),
115 encoded_batch.into(),
116 )
117 }
118 FlightMessage::AffectedRows(rows) => {
119 let metadata = FlightMetadata {
120 affected_rows: Some(AffectedRows { value: rows as _ }),
121 metrics: None,
122 }
123 .encode_to_vec();
124 vec1![FlightData {
125 flight_descriptor: None,
126 data_header: build_none_flight_msg().into(),
127 app_metadata: metadata.into(),
128 data_body: ProstBytes::default(),
129 }]
130 }
131 FlightMessage::Metrics(s) => {
132 let metadata = FlightMetadata {
133 affected_rows: None,
134 metrics: Some(Metrics {
135 metrics: s.as_bytes().to_vec(),
136 }),
137 }
138 .encode_to_vec();
139 vec1![FlightData {
140 flight_descriptor: None,
141 data_header: build_none_flight_msg().into(),
142 app_metadata: metadata.into(),
143 data_body: ProstBytes::default(),
144 }]
145 }
146 }
147 }
148}
149
150#[derive(Default)]
151pub struct FlightDecoder {
152 schema: Option<SchemaRef>,
153 schema_bytes: Option<bytes::Bytes>,
154 dictionaries_by_id: HashMap<i64, ArrayRef>,
155}
156
157impl FlightDecoder {
158 pub fn try_from_schema_bytes(schema_bytes: &bytes::Bytes) -> Result<Self> {
160 let arrow_schema = convert::try_schema_from_flatbuffer_bytes(&schema_bytes[..])
161 .context(error::ArrowSnafu)?;
162 Ok(Self {
163 schema: Some(Arc::new(arrow_schema)),
164 schema_bytes: Some(schema_bytes.clone()),
165 dictionaries_by_id: HashMap::new(),
166 })
167 }
168
169 pub fn try_decode_record_batch(
170 &mut self,
171 data_header: &bytes::Bytes,
172 data_body: &bytes::Bytes,
173 ) -> Result<DfRecordBatch> {
174 let schema = self
175 .schema
176 .as_ref()
177 .context(InvalidFlightDataSnafu {
178 reason: "Should have decoded schema first!",
179 })?
180 .clone();
181 let message = root_as_message(&data_header[..])
182 .map_err(|err| {
183 ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
184 })
185 .context(error::ArrowSnafu)?;
186 let result = message
187 .header_as_record_batch()
188 .ok_or_else(|| {
189 ArrowError::ParseError(
190 "Unable to convert flight data header to a record batch".to_string(),
191 )
192 })
193 .and_then(|batch| {
194 reader::read_record_batch(
195 &Buffer::from(data_body.as_ref()),
196 batch,
197 schema,
198 &HashMap::new(),
199 None,
200 &message.version(),
201 )
202 })
203 .context(error::ArrowSnafu)?;
204 Ok(result)
205 }
206
207 pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<Option<FlightMessage>> {
214 let message = root_as_message(&flight_data.data_header).map_err(|e| {
215 InvalidFlightDataSnafu {
216 reason: e.to_string(),
217 }
218 .build()
219 })?;
220 match message.header_type() {
221 MessageHeader::NONE => {
222 let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
223 .context(DecodeFlightDataSnafu)?;
224 if let Some(AffectedRows { value }) = metadata.affected_rows {
225 return Ok(Some(FlightMessage::AffectedRows(value as _)));
226 }
227 if let Some(Metrics { metrics }) = metadata.metrics {
228 return Ok(Some(FlightMessage::Metrics(
229 String::from_utf8_lossy(&metrics).to_string(),
230 )));
231 }
232 InvalidFlightDataSnafu {
233 reason: "Expecting FlightMetadata have some meaningful content.",
234 }
235 .fail()
236 }
237 MessageHeader::Schema => {
238 let arrow_schema = Arc::new(ArrowSchema::try_from(flight_data).map_err(|e| {
239 InvalidFlightDataSnafu {
240 reason: e.to_string(),
241 }
242 .build()
243 })?);
244 self.schema = Some(arrow_schema.clone());
245 self.schema_bytes = Some(flight_data.data_header.clone());
246 Ok(Some(FlightMessage::Schema(arrow_schema)))
247 }
248 MessageHeader::RecordBatch => {
249 let schema = self.schema.clone().context(InvalidFlightDataSnafu {
250 reason: "Should have decoded schema first!",
251 })?;
252 let arrow_batch = flight_data_to_arrow_batch(
253 flight_data,
254 schema.clone(),
255 &self.dictionaries_by_id,
256 )
257 .map_err(|e| {
258 InvalidFlightDataSnafu {
259 reason: e.to_string(),
260 }
261 .build()
262 })?;
263 Ok(Some(FlightMessage::RecordBatch(arrow_batch)))
264 }
265 MessageHeader::DictionaryBatch => {
266 let dictionary_batch =
267 message
268 .header_as_dictionary_batch()
269 .context(InvalidFlightDataSnafu {
270 reason: "could not get dictionary batch from DictionaryBatch message",
271 })?;
272
273 let schema = self.schema.as_ref().context(InvalidFlightDataSnafu {
274 reason: "schema message is not present previously",
275 })?;
276
277 reader::read_dictionary(
278 &flight_data.data_body.clone().into(),
279 dictionary_batch,
280 schema,
281 &mut self.dictionaries_by_id,
282 &message.version(),
283 )
284 .context(error::ArrowSnafu)?;
285 Ok(None)
286 }
287 other => {
288 let name = other.variant_name().unwrap_or("UNKNOWN");
289 InvalidFlightDataSnafu {
290 reason: format!("Unsupported FlightData type: {name}"),
291 }
292 .fail()
293 }
294 }
295 }
296
297 pub fn schema(&self) -> Option<&SchemaRef> {
298 self.schema.as_ref()
299 }
300
301 pub fn schema_bytes(&self) -> Option<bytes::Bytes> {
302 self.schema_bytes.clone()
303 }
304}
305
306pub fn flight_messages_to_recordbatches(
307 messages: Vec<FlightMessage>,
308) -> Result<Vec<DfRecordBatch>> {
309 if messages.is_empty() {
310 Ok(vec![])
311 } else {
312 let mut recordbatches = Vec::with_capacity(messages.len() - 1);
313
314 match &messages[0] {
315 FlightMessage::Schema(_schema) => {}
316 _ => {
317 return InvalidFlightDataSnafu {
318 reason: "First Flight Message must be schema!",
319 }
320 .fail()
321 }
322 };
323
324 for message in messages.into_iter().skip(1) {
325 match message {
326 FlightMessage::RecordBatch(recordbatch) => recordbatches.push(recordbatch),
327 _ => {
328 return InvalidFlightDataSnafu {
329 reason: "Expect the following Flight Messages are all Recordbatches!",
330 }
331 .fail()
332 }
333 }
334 }
335
336 Ok(recordbatches)
337 }
338}
339
340fn build_none_flight_msg() -> Bytes {
341 let mut builder = FlatBufferBuilder::new();
342
343 let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
344 message.add_version(arrow::ipc::MetadataVersion::V5);
345 message.add_header_type(MessageHeader::NONE);
346 message.add_bodyLength(0);
347
348 let data = message.finish();
349 builder.finish(data, None);
350
351 builder.finished_data().into()
352}
353
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::array::{
        DictionaryArray, Int32Array, StringArray, UInt32Array, UInt8Array,
    };
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::Error;

    /// Builds a single-column `Int32` record batch for `schema`.
    fn int32_batch(schema: &SchemaRef, values: Vec<Option<i32>>) -> DfRecordBatch {
        DfRecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(values)) as _])
            .unwrap()
    }

    /// Decoding a record batch before the schema must fail; afterwards the
    /// schema and both batches must round-trip unchanged.
    #[test]
    fn test_try_decode() -> Result<()> {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));

        let batch1 = int32_batch(&schema, vec![Some(1), None, Some(3)]);
        let batch2 = int32_batch(&schema, vec![None, Some(5)]);

        let flight_data =
            batches_to_flight_data(&schema, vec![batch1.clone(), batch2.clone()]).unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let mut decoder = FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // A record batch cannot be decoded while the schema is still unknown.
        let premature = decoder.try_decode(d2);
        assert!(matches!(premature, Err(Error::InvalidFlightData { .. })));
        assert!(premature
            .unwrap_err()
            .to_string()
            .contains("Should have decoded schema first!"));

        let Some(FlightMessage::Schema(decoded_schema)) = decoder.try_decode(d1)? else {
            panic!("expected a schema message")
        };
        assert_eq!(decoded_schema, schema);
        assert!(decoder.schema.is_some());

        let Some(FlightMessage::RecordBatch(first)) = decoder.try_decode(d2)? else {
            panic!("expected a record batch message")
        };
        assert_eq!(first, batch1);

        let Some(FlightMessage::RecordBatch(second)) = decoder.try_decode(d3)? else {
            panic!("expected a record batch message")
        };
        assert_eq!(second, batch2);
        Ok(())
    }

    /// The helper accepts only "schema followed by record batches".
    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
        let batch1 = int32_batch(&schema, vec![Some(2), None, Some(4)]);
        let batch2 = int32_batch(&schema, vec![None, Some(6)]);
        let expected_batches = vec![batch1.clone(), batch2.clone()];

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::RecordBatch(batch1);
        let m3 = FlightMessage::RecordBatch(batch2);

        // The schema is not the first message.
        let result = flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("First Flight Message must be schema!"));

        // A second schema appears where a record batch is expected.
        let result = flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Expect the following Flight Messages are all Recordbatches!"));

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, expected_batches);
    }

    /// Batches with a dictionary-encoded column survive an encode/decode
    /// round trip, including a change of dictionary between batches.
    #[test]
    fn test_flight_encode_decode_with_dictionary_array() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::UInt8, true),
            Field::new_dictionary("s", DataType::UInt32, DataType::Utf8, true),
        ]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![1, 2, 3])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_value(0, 3),
                    Arc::new(StringArray::from_iter_values(["x"])),
                )) as _,
            ],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![4, 5, 6, 7, 8])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_iter_values([0, 1, 2, 2, 3]),
                    Arc::new(StringArray::from_iter_values(["h", "e", "l", "o"])),
                )) as _,
            ],
        )
        .unwrap();

        let mut encoder = FlightEncoder::default();
        let schema_frames = encoder.encode(FlightMessage::Schema(schema.clone()));
        let batch1_frames = encoder.encode(FlightMessage::RecordBatch(batch1));
        let batch2_frames = encoder.encode(FlightMessage::RecordBatch(batch2));
        assert_eq!(schema_frames.len(), 1);
        // Each batch is encoded as a dictionary frame followed by the batch frame.
        assert_eq!(batch1_frames.len(), 2);
        assert_eq!(batch2_frames.len(), 2);

        let mut decoder = FlightDecoder::default();
        let Some(FlightMessage::Schema(actual_schema)) =
            decoder.try_decode(schema_frames.first())?
        else {
            unreachable!()
        };
        assert_eq!(actual_schema, schema);

        // Dictionary frames update the decoder state and yield no message.
        assert!(decoder.try_decode(&batch1_frames[0])?.is_none());
        let Some(FlightMessage::RecordBatch(decoded_batch1)) =
            decoder.try_decode(&batch1_frames[1])?
        else {
            unreachable!()
        };

        assert!(decoder.try_decode(&batch2_frames[0])?.is_none());
        let Some(FlightMessage::RecordBatch(decoded_batch2)) =
            decoder.try_decode(&batch2_frames[1])?
        else {
            unreachable!()
        };

        let actual =
            arrow::util::pretty::pretty_format_batches(&[decoded_batch1, decoded_batch2])
                .unwrap()
                .to_string();
        let expected = [
            "+---+---+",
            "| i | s |",
            "+---+---+",
            "| 1 | x |",
            "| 2 | x |",
            "| 3 | x |",
            "| 4 | h |",
            "| 5 | e |",
            "| 6 | l |",
            "| 7 | l |",
            "| 8 | o |",
            "+---+---+",
        ]
        .join("\n");
        assert_eq!(actual, expected);
        Ok(())
    }
}