1pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::DfRecordBatch;
25use datatypes::arrow;
26use datatypes::arrow::array::ArrayRef;
27use datatypes::arrow::buffer::Buffer;
28use datatypes::arrow::datatypes::{DataType, Schema as ArrowSchema, SchemaRef};
29use datatypes::arrow::error::ArrowError;
30use datatypes::arrow::ipc::{MessageHeader, convert, reader, root_as_message, writer};
31use flatbuffers::FlatBufferBuilder;
32use prost::Message;
33use prost::bytes::Bytes as ProstBytes;
34use snafu::{OptionExt, ResultExt};
35use vec1::{Vec1, vec1};
36
37use crate::error;
38use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
39
/// Metadata key `x-greptime-flow-extensions`.
///
/// NOTE(review): no use site is visible in this file; presumably this names the
/// request metadata entry that carries flow extensions — confirm against callers.
pub const FLOW_EXTENSIONS_METADATA_KEY: &str = "x-greptime-flow-extensions";
/// Metadata key `x-greptime-snapshot-seqs`.
///
/// NOTE(review): no use site is visible in this file; presumably this names the
/// request metadata entry that carries snapshot sequences — confirm against callers.
pub const SNAPSHOT_SEQS_METADATA_KEY: &str = "x-greptime-snapshot-seqs";
44
/// The in-memory form of one element exchanged through this module's Flight codec.
///
/// [`FlightEncoder::encode`] turns a value of this type into one or more
/// [`FlightData`] frames, and [`FlightDecoder::try_decode`] rebuilds it from a frame.
#[derive(Debug, Clone)]
pub enum FlightMessage {
    /// An Arrow schema.
    Schema(SchemaRef),
    /// One record batch.
    RecordBatch(DfRecordBatch),
    /// A number of affected rows, optionally accompanied by a metrics string.
    /// Carried on the wire inside the frame's `app_metadata` as a `FlightMetadata`.
    AffectedRows {
        // Row count; the encoder stores it in `AffectedRows::value` with an `as` cast.
        rows: usize,
        // Optional metrics text; the encoder stores its bytes in `Metrics::metrics`.
        metrics: Option<String>,
    },
    /// A metrics string on its own, carried the same way but without a row count.
    Metrics(String),
}
55
/// Stateful encoder that turns [`FlightMessage`]s into Arrow Flight [`FlightData`] frames.
pub struct FlightEncoder {
    // IPC write options (including the compression setting) used for schemas and batches.
    write_options: writer::IpcWriteOptions,
    // Arrow IPC generator that serializes record batches and their dictionaries.
    data_gen: writer::IpcDataGenerator,
    // Dictionary bookkeeping shared by all `encode` calls made on this encoder.
    dictionary_tracker: writer::DictionaryTracker,
}
61
62impl Default for FlightEncoder {
63 fn default() -> Self {
64 let write_options = writer::IpcWriteOptions::default()
65 .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
66 .unwrap();
67
68 Self {
69 write_options,
70 data_gen: writer::IpcDataGenerator::default(),
71 dictionary_tracker: writer::DictionaryTracker::new(false),
72 }
73 }
74}
75
76impl FlightEncoder {
77 pub fn with_compression_disabled() -> Self {
79 let write_options = writer::IpcWriteOptions::default()
80 .try_with_compression(None)
81 .unwrap();
82
83 Self {
84 write_options,
85 data_gen: writer::IpcDataGenerator::default(),
86 dictionary_tracker: writer::DictionaryTracker::new(false),
87 }
88 }
89
90 pub fn encode_schema(&self, schema: &ArrowSchema) -> FlightData {
92 SchemaAsIpc::new(schema, &self.write_options).into()
93 }
94
95 pub fn encode(&mut self, flight_message: FlightMessage) -> Vec1<FlightData> {
101 match flight_message {
102 FlightMessage::Schema(schema) => {
103 schema.fields().iter().for_each(|x| {
104 if matches!(x.data_type(), DataType::Dictionary(_, _)) {
105 self.dictionary_tracker.next_dict_id();
106 }
107 });
108
109 vec1![self.encode_schema(schema.as_ref())]
110 }
111 FlightMessage::RecordBatch(record_batch) => {
112 let (encoded_dictionaries, encoded_batch) = self
113 .data_gen
114 .encode(
115 &record_batch,
116 &mut self.dictionary_tracker,
117 &self.write_options,
118 &mut Default::default(),
119 )
120 .expect("DictionaryTracker configured above to not fail on replacement");
121
122 Vec1::from_vec_push(
123 encoded_dictionaries.into_iter().map(Into::into).collect(),
124 encoded_batch.into(),
125 )
126 }
127 FlightMessage::AffectedRows { rows, metrics } => {
128 let metadata = FlightMetadata {
129 affected_rows: Some(AffectedRows { value: rows as _ }),
130 metrics: metrics.map(|s| Metrics {
131 metrics: s.into_bytes(),
132 }),
133 }
134 .encode_to_vec();
135 vec1![FlightData {
136 flight_descriptor: None,
137 data_header: build_none_flight_msg().into(),
138 app_metadata: metadata.into(),
139 data_body: ProstBytes::default(),
140 }]
141 }
142 FlightMessage::Metrics(s) => {
143 let metadata = FlightMetadata {
144 affected_rows: None,
145 metrics: Some(Metrics {
146 metrics: s.as_bytes().to_vec(),
147 }),
148 }
149 .encode_to_vec();
150 vec1![FlightData {
151 flight_descriptor: None,
152 data_header: build_none_flight_msg().into(),
153 app_metadata: metadata.into(),
154 data_body: ProstBytes::default(),
155 }]
156 }
157 }
158 }
159}
160
/// Stateful decoder that turns [`FlightData`] frames back into [`FlightMessage`]s.
///
/// The `Default` value has no schema and no dictionaries.
#[derive(Default)]
pub struct FlightDecoder {
    // Schema used to decode record batches; set by `try_from_schema_bytes` or by
    // decoding a schema frame in `try_decode`.
    schema: Option<SchemaRef>,
    // The bytes the current schema was parsed from: the argument given to
    // `try_from_schema_bytes`, or the `data_header` of the last decoded schema frame.
    schema_bytes: Option<bytes::Bytes>,
    // Dictionaries received so far; filled by dictionary-batch frames in `try_decode`
    // and consulted there when a record batch frame is decoded.
    dictionaries_by_id: HashMap<i64, ArrayRef>,
}
167
168impl FlightDecoder {
169 pub fn try_from_schema_bytes(schema_bytes: &bytes::Bytes) -> Result<Self> {
171 let arrow_schema = convert::try_schema_from_flatbuffer_bytes(&schema_bytes[..])
172 .context(error::ArrowSnafu)?;
173 Ok(Self {
174 schema: Some(Arc::new(arrow_schema)),
175 schema_bytes: Some(schema_bytes.clone()),
176 dictionaries_by_id: HashMap::new(),
177 })
178 }
179
180 pub fn try_decode_record_batch(
181 &mut self,
182 data_header: &bytes::Bytes,
183 data_body: &bytes::Bytes,
184 ) -> Result<DfRecordBatch> {
185 let schema = self
186 .schema
187 .as_ref()
188 .context(InvalidFlightDataSnafu {
189 reason: "Should have decoded schema first!",
190 })?
191 .clone();
192 let message = root_as_message(&data_header[..])
193 .map_err(|err| {
194 ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
195 })
196 .context(error::ArrowSnafu)?;
197 let result = message
198 .header_as_record_batch()
199 .ok_or_else(|| {
200 ArrowError::ParseError(
201 "Unable to convert flight data header to a record batch".to_string(),
202 )
203 })
204 .and_then(|batch| {
205 reader::read_record_batch(
206 &Buffer::from(data_body.as_ref()),
207 batch,
208 schema,
209 &HashMap::new(),
210 None,
211 &message.version(),
212 )
213 })
214 .context(error::ArrowSnafu)?;
215 Ok(result)
216 }
217
218 pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<Option<FlightMessage>> {
225 let message = root_as_message(&flight_data.data_header).map_err(|e| {
226 InvalidFlightDataSnafu {
227 reason: e.to_string(),
228 }
229 .build()
230 })?;
231 match message.header_type() {
232 MessageHeader::NONE => {
233 let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
234 .context(DecodeFlightDataSnafu)?;
235 if let Some(AffectedRows { value }) = metadata.affected_rows {
236 return Ok(Some(FlightMessage::AffectedRows {
237 rows: value as _,
238 metrics: metadata
239 .metrics
240 .map(|m| String::from_utf8_lossy(&m.metrics).to_string()),
241 }));
242 }
243 if let Some(Metrics { metrics }) = metadata.metrics {
244 return Ok(Some(FlightMessage::Metrics(
245 String::from_utf8_lossy(&metrics).to_string(),
246 )));
247 }
248 InvalidFlightDataSnafu {
249 reason: "Expecting FlightMetadata have some meaningful content.",
250 }
251 .fail()
252 }
253 MessageHeader::Schema => {
254 let arrow_schema = Arc::new(ArrowSchema::try_from(flight_data).map_err(|e| {
255 InvalidFlightDataSnafu {
256 reason: e.to_string(),
257 }
258 .build()
259 })?);
260 self.schema = Some(arrow_schema.clone());
261 self.schema_bytes = Some(flight_data.data_header.clone());
262 Ok(Some(FlightMessage::Schema(arrow_schema)))
263 }
264 MessageHeader::RecordBatch => {
265 let schema = self.schema.clone().context(InvalidFlightDataSnafu {
266 reason: "Should have decoded schema first!",
267 })?;
268 let arrow_batch = flight_data_to_arrow_batch(
269 flight_data,
270 schema.clone(),
271 &self.dictionaries_by_id,
272 )
273 .map_err(|e| {
274 InvalidFlightDataSnafu {
275 reason: e.to_string(),
276 }
277 .build()
278 })?;
279 Ok(Some(FlightMessage::RecordBatch(arrow_batch)))
280 }
281 MessageHeader::DictionaryBatch => {
282 let dictionary_batch =
283 message
284 .header_as_dictionary_batch()
285 .context(InvalidFlightDataSnafu {
286 reason: "could not get dictionary batch from DictionaryBatch message",
287 })?;
288
289 let schema = self.schema.as_ref().context(InvalidFlightDataSnafu {
290 reason: "schema message is not present previously",
291 })?;
292
293 reader::read_dictionary(
294 &flight_data.data_body.clone().into(),
295 dictionary_batch,
296 schema,
297 &mut self.dictionaries_by_id,
298 &message.version(),
299 )
300 .context(error::ArrowSnafu)?;
301 Ok(None)
302 }
303 other => {
304 let name = other.variant_name().unwrap_or("UNKNOWN");
305 InvalidFlightDataSnafu {
306 reason: format!("Unsupported FlightData type: {name}"),
307 }
308 .fail()
309 }
310 }
311 }
312
313 pub fn schema(&self) -> Option<&SchemaRef> {
314 self.schema.as_ref()
315 }
316
317 pub fn schema_bytes(&self) -> Option<bytes::Bytes> {
318 self.schema_bytes.clone()
319 }
320}
321
322pub fn flight_messages_to_recordbatches(
323 messages: Vec<FlightMessage>,
324) -> Result<Vec<DfRecordBatch>> {
325 if messages.is_empty() {
326 Ok(vec![])
327 } else {
328 let mut recordbatches = Vec::with_capacity(messages.len() - 1);
329
330 match &messages[0] {
331 FlightMessage::Schema(_schema) => {}
332 _ => {
333 return InvalidFlightDataSnafu {
334 reason: "First Flight Message must be schema!",
335 }
336 .fail();
337 }
338 };
339
340 for message in messages.into_iter().skip(1) {
341 match message {
342 FlightMessage::RecordBatch(recordbatch) => recordbatches.push(recordbatch),
343 _ => {
344 return InvalidFlightDataSnafu {
345 reason: "Expect the following Flight Messages are all Recordbatches!",
346 }
347 .fail();
348 }
349 }
350 }
351
352 Ok(recordbatches)
353 }
354}
355
356fn build_none_flight_msg() -> Bytes {
357 let mut builder = FlatBufferBuilder::new();
358
359 let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
360 message.add_version(arrow::ipc::MetadataVersion::V5);
361 message.add_header_type(MessageHeader::NONE);
362 message.add_bodyLength(0);
363
364 let data = message.finish();
365 builder.finish(data, None);
366
367 builder.finished_data().into()
368}
369
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::array::{
        DictionaryArray, Int32Array, StringArray, UInt8Array, UInt32Array,
    };
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::Error;

    /// Decoding a record batch before the schema fails; after the schema frame is
    /// decoded, both batch frames decode back to the original batches.
    #[test]
    fn test_try_decode() -> Result<()> {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));

        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(5)])) as _],
        )
        .unwrap();

        // One schema frame followed by one frame per batch.
        let flight_data =
            batches_to_flight_data(&schema, vec![batch1.clone(), batch2.clone()]).unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let decoder = &mut FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // A batch frame cannot be decoded while the decoder has no schema.
        let result = decoder.try_decode(d2);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Should have decoded schema first!")
        );

        let message = decoder.try_decode(d1)?.unwrap();
        assert!(matches!(message, FlightMessage::Schema(_)));
        let FlightMessage::Schema(decoded_schema) = message else {
            unreachable!()
        };
        assert_eq!(decoded_schema, schema);

        // Decoding the schema frame must have stored the schema in the decoder.
        assert!(decoder.schema.is_some());

        let message = decoder.try_decode(d2)?.unwrap();
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(actual_batch) = message else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch1);

        let message = decoder.try_decode(d3)?.unwrap();
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(actual_batch) = message else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch2);
        Ok(())
    }

    /// `AffectedRows` round-trips through encoder and decoder both with and
    /// without a metrics payload.
    #[test]
    fn test_affected_rows_metrics_encode_decode() -> Result<()> {
        let metrics = r#"{"region_watermarks":[{"region_id":42,"watermark":7}]}"#;
        let mut encoder = FlightEncoder::default();
        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 3,
            metrics: Some(metrics.to_string()),
        });

        assert_eq!(encoded.len(), 1);

        let mut decoder = FlightDecoder::default();
        let decoded = decoder.try_decode(encoded.first())?.unwrap();
        let FlightMessage::AffectedRows {
            rows,
            metrics: decoded_metrics,
        } = decoded
        else {
            unreachable!()
        };
        assert_eq!(rows, 3);
        assert_eq!(decoded_metrics.as_deref(), Some(metrics));

        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 5,
            metrics: None,
        });
        let decoded = decoder.try_decode(encoded.first())?.unwrap();
        let FlightMessage::AffectedRows {
            rows,
            metrics: decoded_metrics,
        } = decoded
        else {
            unreachable!()
        };
        assert_eq!(rows, 5);
        assert!(decoded_metrics.is_none());

        Ok(())
    }

    /// The helper rejects sequences that do not start with a schema or that contain
    /// a non-batch after the schema, and returns the batches in order otherwise.
    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(6)])) as _],
        )
        .unwrap();
        let recordbatches = vec![batch1.clone(), batch2.clone()];

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::RecordBatch(batch1);
        let m3 = FlightMessage::RecordBatch(batch2);

        // Batch first: the schema must lead the sequence.
        let result = flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("First Flight Message must be schema!")
        );

        // A second schema in the tail is not a record batch.
        let result = flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Expect the following Flight Messages are all Recordbatches!")
        );

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }

    /// Batches with a dictionary column encode to a dictionary frame plus a batch
    /// frame each, and decode back to the original values.
    #[test]
    fn test_flight_encode_decode_with_dictionary_array() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::UInt8, true),
            Field::new_dictionary("s", DataType::UInt32, DataType::Utf8, true),
        ]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![1, 2, 3])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_value(0, 3),
                    Arc::new(StringArray::from_iter_values(["x"])),
                )) as _,
            ],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![4, 5, 6, 7, 8])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_iter_values([0, 1, 2, 2, 3]),
                    Arc::new(StringArray::from_iter_values(["h", "e", "l", "o"])),
                )) as _,
            ],
        )
        .unwrap();

        let message_1 = FlightMessage::Schema(schema.clone());
        let message_2 = FlightMessage::RecordBatch(batch1);
        let message_3 = FlightMessage::RecordBatch(batch2);

        let mut encoder = FlightEncoder::default();
        let encoded_1 = encoder.encode(message_1);
        let encoded_2 = encoder.encode(message_2);
        let encoded_3 = encoder.encode(message_3);
        // The schema is a single frame.
        assert_eq!(encoded_1.len(), 1);
        // Each batch carries one dictionary frame followed by the batch frame.
        assert_eq!(encoded_2.len(), 2);
        assert_eq!(encoded_3.len(), 2);

        let mut decoder = FlightDecoder::default();
        let decoded_1 = decoder.try_decode(encoded_1.first())?;
        let Some(FlightMessage::Schema(actual_schema)) = decoded_1 else {
            unreachable!()
        };
        assert_eq!(actual_schema, schema);
        // A dictionary frame updates the decoder and produces no message.
        let decoded_2 = decoder.try_decode(&encoded_2[0])?;
        assert!(decoded_2.is_none());
        let Some(FlightMessage::RecordBatch(decoded_2)) = decoder.try_decode(&encoded_2[1])? else {
            unreachable!()
        };
        let decoded_3 = decoder.try_decode(&encoded_3[0])?;
        assert!(decoded_3.is_none());
        let Some(FlightMessage::RecordBatch(decoded_3)) = decoder.try_decode(&encoded_3[1])? else {
            unreachable!()
        };
        let actual = arrow::util::pretty::pretty_format_batches(&[decoded_2, decoded_3])
            .unwrap()
            .to_string();
        let expected = r"
+---+---+
| i | s |
+---+---+
| 1 | x |
| 2 | x |
| 3 | x |
| 4 | h |
| 5 | e |
| 6 | l |
| 7 | l |
| 8 | o |
+---+---+";
        assert_eq!(actual, expected.trim());
        Ok(())
    }

    /// `AffectedRows` survives an encode/decode round trip, including the exact
    /// metrics payload when one is attached.
    #[test]
    fn test_affected_rows_roundtrip_through_flight_codec() {
        let mut encoder = FlightEncoder::default();
        let mut decoder = FlightDecoder::default();

        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 7,
            metrics: None,
        });
        let decoded = decoder.try_decode(encoded.first()).unwrap().unwrap();
        assert!(matches!(
            decoded,
            FlightMessage::AffectedRows {
                rows: 7,
                metrics: None,
            }
        ));

        let json = r#"{"region_watermarks":[{"region_id":1,"watermark":99}]}"#;
        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 42,
            metrics: Some(json.to_string()),
        });
        let decoded = decoder.try_decode(encoded.first()).unwrap().unwrap();
        let FlightMessage::AffectedRows { rows, metrics } = decoded else {
            unreachable!()
        };
        assert_eq!(rows, 42);
        // Check the payload itself, not merely that some payload came back.
        assert_eq!(metrics.as_deref(), Some(json));
    }

    /// A hand-built frame that carries only a row count (no metrics) decodes to
    /// `AffectedRows` with `metrics: None`.
    #[test]
    fn test_old_affected_rows_format_decoded_by_new_code() {
        use arrow_flight::FlightData;
        use prost::bytes::Bytes as ProstBytes;

        let old_wire_bytes = FlightData {
            flight_descriptor: None,
            data_header: build_none_flight_msg().into(),
            app_metadata: FlightMetadata {
                affected_rows: Some(AffectedRows { value: 99 }),
                metrics: None,
            }
            .encode_to_vec()
            .into(),
            data_body: ProstBytes::default(),
        };

        let mut decoder = FlightDecoder::default();
        let decoded = decoder.try_decode(&old_wire_bytes).unwrap().unwrap();
        assert!(matches!(
            decoded,
            FlightMessage::AffectedRows {
                rows: 99,
                metrics: None,
            }
        ));
    }
}