1pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::DfRecordBatch;
25use datatypes::arrow;
26use datatypes::arrow::buffer::Buffer;
27use datatypes::arrow::datatypes::{Schema as ArrowSchema, SchemaRef};
28use datatypes::arrow::error::ArrowError;
29use datatypes::arrow::ipc::{convert, reader, root_as_message, writer, MessageHeader};
30use flatbuffers::FlatBufferBuilder;
31use prost::bytes::Bytes as ProstBytes;
32use prost::Message;
33use snafu::{OptionExt, ResultExt};
34
35use crate::error;
36use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
37
38#[derive(Debug, Clone)]
39pub enum FlightMessage {
40 Schema(SchemaRef),
41 RecordBatch(DfRecordBatch),
42 AffectedRows(usize),
43 Metrics(String),
44}
45
46pub struct FlightEncoder {
47 write_options: writer::IpcWriteOptions,
48 data_gen: writer::IpcDataGenerator,
49 dictionary_tracker: writer::DictionaryTracker,
50}
51
52impl Default for FlightEncoder {
53 fn default() -> Self {
54 let write_options = writer::IpcWriteOptions::default()
55 .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
56 .unwrap();
57
58 Self {
59 write_options,
60 data_gen: writer::IpcDataGenerator::default(),
61 dictionary_tracker: writer::DictionaryTracker::new(false),
62 }
63 }
64}
65
66impl FlightEncoder {
67 pub fn encode(&mut self, flight_message: FlightMessage) -> FlightData {
68 match flight_message {
69 FlightMessage::Schema(schema) => SchemaAsIpc::new(&schema, &self.write_options).into(),
70 FlightMessage::RecordBatch(record_batch) => {
71 let (encoded_dictionaries, encoded_batch) = self
72 .data_gen
73 .encoded_batch(
74 &record_batch,
75 &mut self.dictionary_tracker,
76 &self.write_options,
77 )
78 .expect("DictionaryTracker configured above to not fail on replacement");
79
80 debug_assert_eq!(encoded_dictionaries.len(), 0);
86
87 encoded_batch.into()
88 }
89 FlightMessage::AffectedRows(rows) => {
90 let metadata = FlightMetadata {
91 affected_rows: Some(AffectedRows { value: rows as _ }),
92 metrics: None,
93 }
94 .encode_to_vec();
95 FlightData {
96 flight_descriptor: None,
97 data_header: build_none_flight_msg().into(),
98 app_metadata: metadata.into(),
99 data_body: ProstBytes::default(),
100 }
101 }
102 FlightMessage::Metrics(s) => {
103 let metadata = FlightMetadata {
104 affected_rows: None,
105 metrics: Some(Metrics {
106 metrics: s.as_bytes().to_vec(),
107 }),
108 }
109 .encode_to_vec();
110 FlightData {
111 flight_descriptor: None,
112 data_header: build_none_flight_msg().into(),
113 app_metadata: metadata.into(),
114 data_body: ProstBytes::default(),
115 }
116 }
117 }
118 }
119}
120
121#[derive(Default)]
122pub struct FlightDecoder {
123 schema: Option<SchemaRef>,
124 schema_bytes: Option<bytes::Bytes>,
125}
126
127impl FlightDecoder {
128 pub fn try_from_schema_bytes(schema_bytes: &bytes::Bytes) -> Result<Self> {
130 let arrow_schema = convert::try_schema_from_flatbuffer_bytes(&schema_bytes[..])
131 .context(error::ArrowSnafu)?;
132 Ok(Self {
133 schema: Some(Arc::new(arrow_schema)),
134 schema_bytes: Some(schema_bytes.clone()),
135 })
136 }
137
138 pub fn try_decode_record_batch(
139 &mut self,
140 data_header: &bytes::Bytes,
141 data_body: &bytes::Bytes,
142 ) -> Result<DfRecordBatch> {
143 let schema = self
144 .schema
145 .as_ref()
146 .context(InvalidFlightDataSnafu {
147 reason: "Should have decoded schema first!",
148 })?
149 .clone();
150 let message = root_as_message(&data_header[..])
151 .map_err(|err| {
152 ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
153 })
154 .context(error::ArrowSnafu)?;
155 let result = message
156 .header_as_record_batch()
157 .ok_or_else(|| {
158 ArrowError::ParseError(
159 "Unable to convert flight data header to a record batch".to_string(),
160 )
161 })
162 .and_then(|batch| {
163 reader::read_record_batch(
164 &Buffer::from(data_body.as_ref()),
165 batch,
166 schema,
167 &HashMap::new(),
168 None,
169 &message.version(),
170 )
171 })
172 .context(error::ArrowSnafu)?;
173 Ok(result)
174 }
175
176 pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<FlightMessage> {
177 let message = root_as_message(&flight_data.data_header).map_err(|e| {
178 InvalidFlightDataSnafu {
179 reason: e.to_string(),
180 }
181 .build()
182 })?;
183 match message.header_type() {
184 MessageHeader::NONE => {
185 let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
186 .context(DecodeFlightDataSnafu)?;
187 if let Some(AffectedRows { value }) = metadata.affected_rows {
188 return Ok(FlightMessage::AffectedRows(value as _));
189 }
190 if let Some(Metrics { metrics }) = metadata.metrics {
191 return Ok(FlightMessage::Metrics(
192 String::from_utf8_lossy(&metrics).to_string(),
193 ));
194 }
195 InvalidFlightDataSnafu {
196 reason: "Expecting FlightMetadata have some meaningful content.",
197 }
198 .fail()
199 }
200 MessageHeader::Schema => {
201 let arrow_schema = Arc::new(ArrowSchema::try_from(flight_data).map_err(|e| {
202 InvalidFlightDataSnafu {
203 reason: e.to_string(),
204 }
205 .build()
206 })?);
207 self.schema = Some(arrow_schema.clone());
208 self.schema_bytes = Some(flight_data.data_header.clone());
209 Ok(FlightMessage::Schema(arrow_schema))
210 }
211 MessageHeader::RecordBatch => {
212 let schema = self.schema.clone().context(InvalidFlightDataSnafu {
213 reason: "Should have decoded schema first!",
214 })?;
215 let arrow_batch =
216 flight_data_to_arrow_batch(flight_data, schema.clone(), &HashMap::new())
217 .map_err(|e| {
218 InvalidFlightDataSnafu {
219 reason: e.to_string(),
220 }
221 .build()
222 })?;
223 Ok(FlightMessage::RecordBatch(arrow_batch))
224 }
225 other => {
226 let name = other.variant_name().unwrap_or("UNKNOWN");
227 InvalidFlightDataSnafu {
228 reason: format!("Unsupported FlightData type: {name}"),
229 }
230 .fail()
231 }
232 }
233 }
234
235 pub fn schema(&self) -> Option<&SchemaRef> {
236 self.schema.as_ref()
237 }
238
239 pub fn schema_bytes(&self) -> Option<bytes::Bytes> {
240 self.schema_bytes.clone()
241 }
242}
243
244pub fn flight_messages_to_recordbatches(
245 messages: Vec<FlightMessage>,
246) -> Result<Vec<DfRecordBatch>> {
247 if messages.is_empty() {
248 Ok(vec![])
249 } else {
250 let mut recordbatches = Vec::with_capacity(messages.len() - 1);
251
252 match &messages[0] {
253 FlightMessage::Schema(_schema) => {}
254 _ => {
255 return InvalidFlightDataSnafu {
256 reason: "First Flight Message must be schema!",
257 }
258 .fail()
259 }
260 };
261
262 for message in messages.into_iter().skip(1) {
263 match message {
264 FlightMessage::RecordBatch(recordbatch) => recordbatches.push(recordbatch),
265 _ => {
266 return InvalidFlightDataSnafu {
267 reason: "Expect the following Flight Messages are all Recordbatches!",
268 }
269 .fail()
270 }
271 }
272 }
273
274 Ok(recordbatches)
275 }
276}
277
278fn build_none_flight_msg() -> Bytes {
279 let mut builder = FlatBufferBuilder::new();
280
281 let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
282 message.add_version(arrow::ipc::MetadataVersion::V5);
283 message.add_header_type(MessageHeader::NONE);
284 message.add_bodyLength(0);
285
286 let data = message.finish();
287 builder.finish(data, None);
288
289 builder.finished_data().into()
290}
291
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::array::Int32Array;
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::Error;

    /// Record batches are rejected until a schema is decoded; afterwards every batch of the
    /// stream decodes back to its original value.
    #[test]
    fn test_try_decode() {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));

        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(5)])) as _],
        )
        .unwrap();

        // One schema frame followed by one frame per batch.
        let flight_data =
            batches_to_flight_data(&schema, vec![batch1.clone(), batch2.clone()]).unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let decoder = &mut FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // A record batch before any schema is rejected.
        let result = decoder.try_decode(d2);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Should have decoded schema first!"));

        let FlightMessage::Schema(decoded_schema) = decoder.try_decode(d1).unwrap() else {
            panic!("expected a schema message");
        };
        assert_eq!(decoded_schema, schema);

        // The decoder must now remember the schema and its raw bytes.
        assert!(decoder.schema.is_some());
        assert!(decoder.schema_bytes.is_some());

        let FlightMessage::RecordBatch(actual_batch) = decoder.try_decode(d2).unwrap() else {
            panic!("expected a record batch message");
        };
        assert_eq!(actual_batch, batch1);

        let FlightMessage::RecordBatch(actual_batch) = decoder.try_decode(d3).unwrap() else {
            panic!("expected a record batch message");
        };
        assert_eq!(actual_batch, batch2);
    }

    /// Schema bytes captured by one decoder let a fresh decoder read raw record batch
    /// header/body pairs; without a schema, that read is rejected.
    #[test]
    fn test_try_decode_record_batch_from_schema_bytes() {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));
        let batch = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(7), None])) as _],
        )
        .unwrap();

        let flight_data = batches_to_flight_data(&schema, vec![batch.clone()]).unwrap();
        let [schema_data, batch_data] = flight_data.as_slice() else {
            panic!("expected one schema frame and one record batch frame");
        };

        let mut decoder = FlightDecoder::default();
        let _ = decoder.try_decode(schema_data).unwrap();
        let schema_bytes = decoder.schema_bytes().unwrap();

        let mut restored = FlightDecoder::try_from_schema_bytes(&schema_bytes).unwrap();
        assert_eq!(restored.schema(), Some(&schema));
        let decoded = restored
            .try_decode_record_batch(&batch_data.data_header, &batch_data.data_body)
            .unwrap();
        assert_eq!(decoded, batch);

        let result = FlightDecoder::default()
            .try_decode_record_batch(&batch_data.data_header, &batch_data.data_body);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
    }

    /// Metadata-only messages survive an encoder -> decoder round trip with an empty body.
    #[test]
    fn test_encode_decode_metadata_messages() {
        let mut encoder = FlightEncoder::default();
        let mut decoder = FlightDecoder::default();

        let data = encoder.encode(FlightMessage::AffectedRows(42));
        assert!(data.data_body.is_empty());
        let FlightMessage::AffectedRows(rows) = decoder.try_decode(&data).unwrap() else {
            panic!("expected an affected rows message");
        };
        assert_eq!(rows, 42);

        let data = encoder.encode(FlightMessage::Metrics("{\"rows\":1}".to_string()));
        assert!(data.data_body.is_empty());
        let FlightMessage::Metrics(metrics) = decoder.try_decode(&data).unwrap() else {
            panic!("expected a metrics message");
        };
        assert_eq!(metrics, "{\"rows\":1}");
    }

    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(6)])) as _],
        )
        .unwrap();
        let recordbatches = vec![batch1.clone(), batch2.clone()];

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::RecordBatch(batch1);
        let m3 = FlightMessage::RecordBatch(batch2);

        // The stream must start with a schema.
        let result = flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("First Flight Message must be schema!"));

        // Only record batches may follow the schema.
        let result = flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Expect the following Flight Messages are all Recordbatches!"));

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }
}