1pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::DfRecordBatch;
25use datatypes::arrow;
26use datatypes::arrow::buffer::Buffer;
27use datatypes::arrow::datatypes::{Schema as ArrowSchema, SchemaRef};
28use datatypes::arrow::error::ArrowError;
29use datatypes::arrow::ipc::{convert, reader, root_as_message, writer, MessageHeader};
30use flatbuffers::FlatBufferBuilder;
31use prost::bytes::Bytes as ProstBytes;
32use prost::Message;
33use snafu::{OptionExt, ResultExt};
34
35use crate::error;
36use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
37
38#[derive(Debug, Clone)]
39pub enum FlightMessage {
40 Schema(SchemaRef),
41 RecordBatch(DfRecordBatch),
42 AffectedRows(usize),
43 Metrics(String),
44}
45
46pub struct FlightEncoder {
47 write_options: writer::IpcWriteOptions,
48 data_gen: writer::IpcDataGenerator,
49 dictionary_tracker: writer::DictionaryTracker,
50}
51
52impl Default for FlightEncoder {
53 fn default() -> Self {
54 let write_options = writer::IpcWriteOptions::default()
55 .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
56 .unwrap();
57
58 Self {
59 write_options,
60 data_gen: writer::IpcDataGenerator::default(),
61 dictionary_tracker: writer::DictionaryTracker::new(false),
62 }
63 }
64}
65
66impl FlightEncoder {
67 pub fn with_compression_disabled() -> Self {
69 let write_options = writer::IpcWriteOptions::default()
70 .try_with_compression(None)
71 .unwrap();
72
73 Self {
74 write_options,
75 data_gen: writer::IpcDataGenerator::default(),
76 dictionary_tracker: writer::DictionaryTracker::new(false),
77 }
78 }
79
80 pub fn encode(&mut self, flight_message: FlightMessage) -> FlightData {
81 match flight_message {
82 FlightMessage::Schema(schema) => SchemaAsIpc::new(&schema, &self.write_options).into(),
83 FlightMessage::RecordBatch(record_batch) => {
84 let (encoded_dictionaries, encoded_batch) = self
85 .data_gen
86 .encoded_batch(
87 &record_batch,
88 &mut self.dictionary_tracker,
89 &self.write_options,
90 )
91 .expect("DictionaryTracker configured above to not fail on replacement");
92
93 debug_assert_eq!(encoded_dictionaries.len(), 0);
99
100 encoded_batch.into()
101 }
102 FlightMessage::AffectedRows(rows) => {
103 let metadata = FlightMetadata {
104 affected_rows: Some(AffectedRows { value: rows as _ }),
105 metrics: None,
106 }
107 .encode_to_vec();
108 FlightData {
109 flight_descriptor: None,
110 data_header: build_none_flight_msg().into(),
111 app_metadata: metadata.into(),
112 data_body: ProstBytes::default(),
113 }
114 }
115 FlightMessage::Metrics(s) => {
116 let metadata = FlightMetadata {
117 affected_rows: None,
118 metrics: Some(Metrics {
119 metrics: s.as_bytes().to_vec(),
120 }),
121 }
122 .encode_to_vec();
123 FlightData {
124 flight_descriptor: None,
125 data_header: build_none_flight_msg().into(),
126 app_metadata: metadata.into(),
127 data_body: ProstBytes::default(),
128 }
129 }
130 }
131 }
132}
133
134#[derive(Default)]
135pub struct FlightDecoder {
136 schema: Option<SchemaRef>,
137 schema_bytes: Option<bytes::Bytes>,
138}
139
140impl FlightDecoder {
141 pub fn try_from_schema_bytes(schema_bytes: &bytes::Bytes) -> Result<Self> {
143 let arrow_schema = convert::try_schema_from_flatbuffer_bytes(&schema_bytes[..])
144 .context(error::ArrowSnafu)?;
145 Ok(Self {
146 schema: Some(Arc::new(arrow_schema)),
147 schema_bytes: Some(schema_bytes.clone()),
148 })
149 }
150
151 pub fn try_decode_record_batch(
152 &mut self,
153 data_header: &bytes::Bytes,
154 data_body: &bytes::Bytes,
155 ) -> Result<DfRecordBatch> {
156 let schema = self
157 .schema
158 .as_ref()
159 .context(InvalidFlightDataSnafu {
160 reason: "Should have decoded schema first!",
161 })?
162 .clone();
163 let message = root_as_message(&data_header[..])
164 .map_err(|err| {
165 ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
166 })
167 .context(error::ArrowSnafu)?;
168 let result = message
169 .header_as_record_batch()
170 .ok_or_else(|| {
171 ArrowError::ParseError(
172 "Unable to convert flight data header to a record batch".to_string(),
173 )
174 })
175 .and_then(|batch| {
176 reader::read_record_batch(
177 &Buffer::from(data_body.as_ref()),
178 batch,
179 schema,
180 &HashMap::new(),
181 None,
182 &message.version(),
183 )
184 })
185 .context(error::ArrowSnafu)?;
186 Ok(result)
187 }
188
189 pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<FlightMessage> {
190 let message = root_as_message(&flight_data.data_header).map_err(|e| {
191 InvalidFlightDataSnafu {
192 reason: e.to_string(),
193 }
194 .build()
195 })?;
196 match message.header_type() {
197 MessageHeader::NONE => {
198 let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
199 .context(DecodeFlightDataSnafu)?;
200 if let Some(AffectedRows { value }) = metadata.affected_rows {
201 return Ok(FlightMessage::AffectedRows(value as _));
202 }
203 if let Some(Metrics { metrics }) = metadata.metrics {
204 return Ok(FlightMessage::Metrics(
205 String::from_utf8_lossy(&metrics).to_string(),
206 ));
207 }
208 InvalidFlightDataSnafu {
209 reason: "Expecting FlightMetadata have some meaningful content.",
210 }
211 .fail()
212 }
213 MessageHeader::Schema => {
214 let arrow_schema = Arc::new(ArrowSchema::try_from(flight_data).map_err(|e| {
215 InvalidFlightDataSnafu {
216 reason: e.to_string(),
217 }
218 .build()
219 })?);
220 self.schema = Some(arrow_schema.clone());
221 self.schema_bytes = Some(flight_data.data_header.clone());
222 Ok(FlightMessage::Schema(arrow_schema))
223 }
224 MessageHeader::RecordBatch => {
225 let schema = self.schema.clone().context(InvalidFlightDataSnafu {
226 reason: "Should have decoded schema first!",
227 })?;
228 let arrow_batch =
229 flight_data_to_arrow_batch(flight_data, schema.clone(), &HashMap::new())
230 .map_err(|e| {
231 InvalidFlightDataSnafu {
232 reason: e.to_string(),
233 }
234 .build()
235 })?;
236 Ok(FlightMessage::RecordBatch(arrow_batch))
237 }
238 other => {
239 let name = other.variant_name().unwrap_or("UNKNOWN");
240 InvalidFlightDataSnafu {
241 reason: format!("Unsupported FlightData type: {name}"),
242 }
243 .fail()
244 }
245 }
246 }
247
248 pub fn schema(&self) -> Option<&SchemaRef> {
249 self.schema.as_ref()
250 }
251
252 pub fn schema_bytes(&self) -> Option<bytes::Bytes> {
253 self.schema_bytes.clone()
254 }
255}
256
257pub fn flight_messages_to_recordbatches(
258 messages: Vec<FlightMessage>,
259) -> Result<Vec<DfRecordBatch>> {
260 if messages.is_empty() {
261 Ok(vec![])
262 } else {
263 let mut recordbatches = Vec::with_capacity(messages.len() - 1);
264
265 match &messages[0] {
266 FlightMessage::Schema(_schema) => {}
267 _ => {
268 return InvalidFlightDataSnafu {
269 reason: "First Flight Message must be schema!",
270 }
271 .fail()
272 }
273 };
274
275 for message in messages.into_iter().skip(1) {
276 match message {
277 FlightMessage::RecordBatch(recordbatch) => recordbatches.push(recordbatch),
278 _ => {
279 return InvalidFlightDataSnafu {
280 reason: "Expect the following Flight Messages are all Recordbatches!",
281 }
282 .fail()
283 }
284 }
285 }
286
287 Ok(recordbatches)
288 }
289}
290
291fn build_none_flight_msg() -> Bytes {
292 let mut builder = FlatBufferBuilder::new();
293
294 let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
295 message.add_version(arrow::ipc::MetadataVersion::V5);
296 message.add_header_type(MessageHeader::NONE);
297 message.add_bodyLength(0);
298
299 let data = message.finish();
300 builder.finish(data, None);
301
302 builder.finished_data().into()
303}
304
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::Error;

    /// Builds a record batch with a single nullable `Int32` column.
    fn int32_batch(schema: &SchemaRef, values: Vec<Option<i32>>) -> DfRecordBatch {
        DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(values)) as _],
        )
        .unwrap()
    }

    /// Asserts that `result` is an `InvalidFlightData` error whose message
    /// contains `expected`.
    fn assert_invalid_flight_data<T: std::fmt::Debug>(
        result: std::result::Result<T, Error>,
        expected: &str,
    ) {
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result.unwrap_err().to_string().contains(expected));
    }

    /// Asserts that `message` is a record batch and returns it.
    fn expect_record_batch(message: FlightMessage) -> DfRecordBatch {
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(batch) = message else {
            unreachable!()
        };
        batch
    }

    #[test]
    fn test_try_decode() {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));
        let batch1 = int32_batch(&schema, vec![Some(1), None, Some(3)]);
        let batch2 = int32_batch(&schema, vec![None, Some(5)]);

        // One schema frame followed by one frame per batch.
        let flight_data =
            batches_to_flight_data(&schema, vec![batch1.clone(), batch2.clone()]).unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let decoder = &mut FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // A record batch cannot be decoded before the schema is known.
        assert_invalid_flight_data(
            decoder.try_decode(d2),
            "Should have decoded schema first!",
        );

        let message = decoder.try_decode(d1).unwrap();
        assert!(matches!(message, FlightMessage::Schema(_)));
        let FlightMessage::Schema(decoded_schema) = message else {
            unreachable!()
        };
        assert_eq!(decoded_schema, schema);

        // The decoder must have remembered the schema.
        let _ = decoder.schema.as_ref().unwrap();

        let actual_batch = expect_record_batch(decoder.try_decode(d2).unwrap());
        assert_eq!(actual_batch, batch1);

        let actual_batch = expect_record_batch(decoder.try_decode(d3).unwrap());
        assert_eq!(actual_batch, batch2);
    }

    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
        let batch1 = int32_batch(&schema, vec![Some(2), None, Some(4)]);
        let batch2 = int32_batch(&schema, vec![None, Some(6)]);
        let recordbatches = vec![batch1.clone(), batch2.clone()];

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::RecordBatch(batch1);
        let m3 = FlightMessage::RecordBatch(batch2);

        // The schema must come first.
        assert_invalid_flight_data(
            flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]),
            "First Flight Message must be schema!",
        );

        // Only record batches may follow the schema.
        assert_invalid_flight_data(
            flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]),
            "Expect the following Flight Messages are all Recordbatches!",
        );

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }
}