1pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::DfRecordBatch;
25use datatypes::arrow;
26use datatypes::arrow::array::ArrayRef;
27use datatypes::arrow::buffer::Buffer;
28use datatypes::arrow::datatypes::{DataType, Schema as ArrowSchema, SchemaRef};
29use datatypes::arrow::error::ArrowError;
30use datatypes::arrow::ipc::{MessageHeader, convert, reader, root_as_message, writer};
31use flatbuffers::FlatBufferBuilder;
32use prost::Message;
33use prost::bytes::Bytes as ProstBytes;
34use snafu::{OptionExt, ResultExt};
35use vec1::{Vec1, vec1};
36
37use crate::error;
38use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
39
/// A logical message carried over an Arrow Flight stream.
///
/// [`FlightEncoder::encode`] turns each variant into one or more [`FlightData`]s, and
/// [`FlightDecoder::try_decode`] turns [`FlightData`]s back into these variants.
#[derive(Debug, Clone)]
pub enum FlightMessage {
    /// The Arrow schema of the record batches that follow.
    Schema(SchemaRef),
    /// One Arrow record batch.
    RecordBatch(DfRecordBatch),
    /// An affected-rows count, carried in `FlightMetadata::affected_rows` of the `app_metadata`.
    AffectedRows(usize),
    /// A metrics string, carried as bytes in `FlightMetadata::metrics` of the `app_metadata`.
    Metrics(String),
}
47
/// Stateful encoder that converts [`FlightMessage`]s into Arrow Flight [`FlightData`].
///
/// It is stateful because the dictionary tracker persists across `encode` calls.
pub struct FlightEncoder {
    // IPC write options; the constructors decide whether bodies are compressed.
    write_options: writer::IpcWriteOptions,
    // Produces IPC-encoded dictionary and record batch messages.
    data_gen: writer::IpcDataGenerator,
    // Tracks dictionaries across batches. It is created with `new(false)`, i.e. replacing a
    // dictionary is not treated as an error.
    dictionary_tracker: writer::DictionaryTracker,
}
53
54impl Default for FlightEncoder {
55 fn default() -> Self {
56 let write_options = writer::IpcWriteOptions::default()
57 .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
58 .unwrap();
59
60 Self {
61 write_options,
62 data_gen: writer::IpcDataGenerator::default(),
63 dictionary_tracker: writer::DictionaryTracker::new(false),
64 }
65 }
66}
67
68impl FlightEncoder {
69 pub fn with_compression_disabled() -> Self {
71 let write_options = writer::IpcWriteOptions::default()
72 .try_with_compression(None)
73 .unwrap();
74
75 Self {
76 write_options,
77 data_gen: writer::IpcDataGenerator::default(),
78 dictionary_tracker: writer::DictionaryTracker::new(false),
79 }
80 }
81
82 pub fn encode_schema(&self, schema: &ArrowSchema) -> FlightData {
84 SchemaAsIpc::new(schema, &self.write_options).into()
85 }
86
87 pub fn encode(&mut self, flight_message: FlightMessage) -> Vec1<FlightData> {
93 match flight_message {
94 FlightMessage::Schema(schema) => {
95 schema.fields().iter().for_each(|x| {
96 if matches!(x.data_type(), DataType::Dictionary(_, _)) {
97 self.dictionary_tracker.next_dict_id();
98 }
99 });
100
101 vec1![self.encode_schema(schema.as_ref())]
102 }
103 FlightMessage::RecordBatch(record_batch) => {
104 let (encoded_dictionaries, encoded_batch) = self
105 .data_gen
106 .encode(
107 &record_batch,
108 &mut self.dictionary_tracker,
109 &self.write_options,
110 &mut Default::default(),
111 )
112 .expect("DictionaryTracker configured above to not fail on replacement");
113
114 Vec1::from_vec_push(
115 encoded_dictionaries.into_iter().map(Into::into).collect(),
116 encoded_batch.into(),
117 )
118 }
119 FlightMessage::AffectedRows(rows) => {
120 let metadata = FlightMetadata {
121 affected_rows: Some(AffectedRows { value: rows as _ }),
122 metrics: None,
123 }
124 .encode_to_vec();
125 vec1![FlightData {
126 flight_descriptor: None,
127 data_header: build_none_flight_msg().into(),
128 app_metadata: metadata.into(),
129 data_body: ProstBytes::default(),
130 }]
131 }
132 FlightMessage::Metrics(s) => {
133 let metadata = FlightMetadata {
134 affected_rows: None,
135 metrics: Some(Metrics {
136 metrics: s.as_bytes().to_vec(),
137 }),
138 }
139 .encode_to_vec();
140 vec1![FlightData {
141 flight_descriptor: None,
142 data_header: build_none_flight_msg().into(),
143 app_metadata: metadata.into(),
144 data_body: ProstBytes::default(),
145 }]
146 }
147 }
148 }
149}
150
/// Stateful decoder that converts Arrow Flight [`FlightData`] back into [`FlightMessage`]s.
///
/// A schema must be decoded (or supplied via `try_from_schema_bytes`) before record batches
/// can be decoded.
#[derive(Default)]
pub struct FlightDecoder {
    // The most recently decoded schema, if any.
    schema: Option<SchemaRef>,
    // Raw IPC header bytes of the schema message the schema was read from.
    schema_bytes: Option<bytes::Bytes>,
    // Dictionaries loaded from `DictionaryBatch` messages, keyed by dictionary id.
    dictionaries_by_id: HashMap<i64, ArrayRef>,
}
157
158impl FlightDecoder {
159 pub fn try_from_schema_bytes(schema_bytes: &bytes::Bytes) -> Result<Self> {
161 let arrow_schema = convert::try_schema_from_flatbuffer_bytes(&schema_bytes[..])
162 .context(error::ArrowSnafu)?;
163 Ok(Self {
164 schema: Some(Arc::new(arrow_schema)),
165 schema_bytes: Some(schema_bytes.clone()),
166 dictionaries_by_id: HashMap::new(),
167 })
168 }
169
170 pub fn try_decode_record_batch(
171 &mut self,
172 data_header: &bytes::Bytes,
173 data_body: &bytes::Bytes,
174 ) -> Result<DfRecordBatch> {
175 let schema = self
176 .schema
177 .as_ref()
178 .context(InvalidFlightDataSnafu {
179 reason: "Should have decoded schema first!",
180 })?
181 .clone();
182 let message = root_as_message(&data_header[..])
183 .map_err(|err| {
184 ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
185 })
186 .context(error::ArrowSnafu)?;
187 let result = message
188 .header_as_record_batch()
189 .ok_or_else(|| {
190 ArrowError::ParseError(
191 "Unable to convert flight data header to a record batch".to_string(),
192 )
193 })
194 .and_then(|batch| {
195 reader::read_record_batch(
196 &Buffer::from(data_body.as_ref()),
197 batch,
198 schema,
199 &HashMap::new(),
200 None,
201 &message.version(),
202 )
203 })
204 .context(error::ArrowSnafu)?;
205 Ok(result)
206 }
207
208 pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<Option<FlightMessage>> {
215 let message = root_as_message(&flight_data.data_header).map_err(|e| {
216 InvalidFlightDataSnafu {
217 reason: e.to_string(),
218 }
219 .build()
220 })?;
221 match message.header_type() {
222 MessageHeader::NONE => {
223 let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
224 .context(DecodeFlightDataSnafu)?;
225 if let Some(AffectedRows { value }) = metadata.affected_rows {
226 return Ok(Some(FlightMessage::AffectedRows(value as _)));
227 }
228 if let Some(Metrics { metrics }) = metadata.metrics {
229 return Ok(Some(FlightMessage::Metrics(
230 String::from_utf8_lossy(&metrics).to_string(),
231 )));
232 }
233 InvalidFlightDataSnafu {
234 reason: "Expecting FlightMetadata have some meaningful content.",
235 }
236 .fail()
237 }
238 MessageHeader::Schema => {
239 let arrow_schema = Arc::new(ArrowSchema::try_from(flight_data).map_err(|e| {
240 InvalidFlightDataSnafu {
241 reason: e.to_string(),
242 }
243 .build()
244 })?);
245 self.schema = Some(arrow_schema.clone());
246 self.schema_bytes = Some(flight_data.data_header.clone());
247 Ok(Some(FlightMessage::Schema(arrow_schema)))
248 }
249 MessageHeader::RecordBatch => {
250 let schema = self.schema.clone().context(InvalidFlightDataSnafu {
251 reason: "Should have decoded schema first!",
252 })?;
253 let arrow_batch = flight_data_to_arrow_batch(
254 flight_data,
255 schema.clone(),
256 &self.dictionaries_by_id,
257 )
258 .map_err(|e| {
259 InvalidFlightDataSnafu {
260 reason: e.to_string(),
261 }
262 .build()
263 })?;
264 Ok(Some(FlightMessage::RecordBatch(arrow_batch)))
265 }
266 MessageHeader::DictionaryBatch => {
267 let dictionary_batch =
268 message
269 .header_as_dictionary_batch()
270 .context(InvalidFlightDataSnafu {
271 reason: "could not get dictionary batch from DictionaryBatch message",
272 })?;
273
274 let schema = self.schema.as_ref().context(InvalidFlightDataSnafu {
275 reason: "schema message is not present previously",
276 })?;
277
278 reader::read_dictionary(
279 &flight_data.data_body.clone().into(),
280 dictionary_batch,
281 schema,
282 &mut self.dictionaries_by_id,
283 &message.version(),
284 )
285 .context(error::ArrowSnafu)?;
286 Ok(None)
287 }
288 other => {
289 let name = other.variant_name().unwrap_or("UNKNOWN");
290 InvalidFlightDataSnafu {
291 reason: format!("Unsupported FlightData type: {name}"),
292 }
293 .fail()
294 }
295 }
296 }
297
298 pub fn schema(&self) -> Option<&SchemaRef> {
299 self.schema.as_ref()
300 }
301
302 pub fn schema_bytes(&self) -> Option<bytes::Bytes> {
303 self.schema_bytes.clone()
304 }
305}
306
307pub fn flight_messages_to_recordbatches(
308 messages: Vec<FlightMessage>,
309) -> Result<Vec<DfRecordBatch>> {
310 if messages.is_empty() {
311 Ok(vec![])
312 } else {
313 let mut recordbatches = Vec::with_capacity(messages.len() - 1);
314
315 match &messages[0] {
316 FlightMessage::Schema(_schema) => {}
317 _ => {
318 return InvalidFlightDataSnafu {
319 reason: "First Flight Message must be schema!",
320 }
321 .fail();
322 }
323 };
324
325 for message in messages.into_iter().skip(1) {
326 match message {
327 FlightMessage::RecordBatch(recordbatch) => recordbatches.push(recordbatch),
328 _ => {
329 return InvalidFlightDataSnafu {
330 reason: "Expect the following Flight Messages are all Recordbatches!",
331 }
332 .fail();
333 }
334 }
335 }
336
337 Ok(recordbatches)
338 }
339}
340
341fn build_none_flight_msg() -> Bytes {
342 let mut builder = FlatBufferBuilder::new();
343
344 let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
345 message.add_version(arrow::ipc::MetadataVersion::V5);
346 message.add_header_type(MessageHeader::NONE);
347 message.add_bodyLength(0);
348
349 let data = message.finish();
350 builder.finish(data, None);
351
352 builder.finished_data().into()
353}
354
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::array::{
        DictionaryArray, Int32Array, StringArray, UInt8Array, UInt32Array,
    };
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::Error;

    /// A record batch decoded before any schema is rejected. After the schema message,
    /// each record batch decodes back to the original batch.
    #[test]
    fn test_try_decode() -> Result<()> {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));

        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(5)])) as _],
        )
        .unwrap();

        // One schema message followed by one message per batch.
        let flight_data =
            batches_to_flight_data(&schema, vec![batch1.clone(), batch2.clone()]).unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let decoder = &mut FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // Decoding a record batch without a schema must fail.
        let result = decoder.try_decode(d2);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Should have decoded schema first!")
        );

        let message = decoder.try_decode(d1)?.unwrap();
        assert!(matches!(message, FlightMessage::Schema(_)));
        let FlightMessage::Schema(decoded_schema) = message else {
            unreachable!()
        };
        assert_eq!(decoded_schema, schema);

        // The decoder now remembers the schema.
        let _ = decoder.schema.as_ref().unwrap();

        let message = decoder.try_decode(d2)?.unwrap();
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(actual_batch) = message else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch1);

        let message = decoder.try_decode(d3)?.unwrap();
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(actual_batch) = message else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch2);
        Ok(())
    }

    /// Covers the happy path and both ordering errors of `flight_messages_to_recordbatches`.
    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(6)])) as _],
        )
        .unwrap();
        let recordbatches = vec![batch1.clone(), batch2.clone()];

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::RecordBatch(batch1);
        let m3 = FlightMessage::RecordBatch(batch2);

        // The first message is not a schema.
        let result = flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("First Flight Message must be schema!")
        );

        // A schema appears after the first message.
        let result = flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Expect the following Flight Messages are all Recordbatches!")
        );

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }

    /// Round-trips batches with a dictionary-encoded column through `FlightEncoder` and
    /// `FlightDecoder`.
    #[test]
    fn test_flight_encode_decode_with_dictionary_array() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::UInt8, true),
            Field::new_dictionary("s", DataType::UInt32, DataType::Utf8, true),
        ]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![1, 2, 3])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_value(0, 3),
                    Arc::new(StringArray::from_iter_values(["x"])),
                )) as _,
            ],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![4, 5, 6, 7, 8])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_iter_values([0, 1, 2, 2, 3]),
                    Arc::new(StringArray::from_iter_values(["h", "e", "l", "o"])),
                )) as _,
            ],
        )
        .unwrap();

        let message_1 = FlightMessage::Schema(schema.clone());
        let message_2 = FlightMessage::RecordBatch(batch1);
        let message_3 = FlightMessage::RecordBatch(batch2);

        let mut encoder = FlightEncoder::default();
        let encoded_1 = encoder.encode(message_1);
        let encoded_2 = encoder.encode(message_2);
        let encoded_3 = encoder.encode(message_3);
        // The schema encodes to a single FlightData.
        assert_eq!(encoded_1.len(), 1);
        // Each batch encodes to a dictionary batch followed by the record batch; the
        // decoding below relies on that order.
        assert_eq!(encoded_2.len(), 2);
        assert_eq!(encoded_3.len(), 2);

        let mut decoder = FlightDecoder::default();
        let decoded_1 = decoder.try_decode(encoded_1.first())?;
        let Some(FlightMessage::Schema(actual_schema)) = decoded_1 else {
            unreachable!()
        };
        assert_eq!(actual_schema, schema);
        // Dictionary batches update decoder state and yield no message.
        let decoded_2 = decoder.try_decode(&encoded_2[0])?;
        assert!(decoded_2.is_none());
        let Some(FlightMessage::RecordBatch(decoded_2)) = decoder.try_decode(&encoded_2[1])? else {
            unreachable!()
        };
        let decoded_3 = decoder.try_decode(&encoded_3[0])?;
        assert!(decoded_3.is_none());
        let Some(FlightMessage::RecordBatch(decoded_3)) = decoder.try_decode(&encoded_3[1])? else {
            unreachable!()
        };
        let actual = arrow::util::pretty::pretty_format_batches(&[decoded_2, decoded_3])
            .unwrap()
            .to_string();
        let expected = r"
+---+---+
| i | s |
+---+---+
| 1 | x |
| 2 | x |
| 3 | x |
| 4 | h |
| 5 | e |
| 6 | l |
| 7 | l |
| 8 | o |
+---+---+";
        assert_eq!(actual, expected.trim());
        Ok(())
    }
}