1pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::{RecordBatch, RecordBatches};
25use datatypes::arrow;
26use datatypes::arrow::datatypes::Schema as ArrowSchema;
27use datatypes::arrow::ipc::{root_as_message, writer, MessageHeader};
28use datatypes::schema::{Schema, SchemaRef};
29use flatbuffers::FlatBufferBuilder;
30use prost::bytes::Bytes as ProstBytes;
31use prost::Message;
32use snafu::{OptionExt, ResultExt};
33
34use crate::error::{
35 ConvertArrowSchemaSnafu, CreateRecordBatchSnafu, DecodeFlightDataSnafu, InvalidFlightDataSnafu,
36 Result,
37};
38
39#[derive(Debug, Clone)]
40pub enum FlightMessage {
41 Schema(SchemaRef),
42 Recordbatch(RecordBatch),
43 AffectedRows(usize),
44 Metrics(String),
45}
46
47pub struct FlightEncoder {
48 write_options: writer::IpcWriteOptions,
49 data_gen: writer::IpcDataGenerator,
50 dictionary_tracker: writer::DictionaryTracker,
51}
52
53impl Default for FlightEncoder {
54 fn default() -> Self {
55 let write_options = writer::IpcWriteOptions::default()
56 .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
57 .unwrap();
58
59 Self {
60 write_options,
61 data_gen: writer::IpcDataGenerator::default(),
62 dictionary_tracker: writer::DictionaryTracker::new(false),
63 }
64 }
65}
66
67impl FlightEncoder {
68 pub fn encode(&mut self, flight_message: FlightMessage) -> FlightData {
69 match flight_message {
70 FlightMessage::Schema(schema) => {
71 SchemaAsIpc::new(schema.arrow_schema(), &self.write_options).into()
72 }
73 FlightMessage::Recordbatch(recordbatch) => {
74 let (encoded_dictionaries, encoded_batch) = self
75 .data_gen
76 .encoded_batch(
77 recordbatch.df_record_batch(),
78 &mut self.dictionary_tracker,
79 &self.write_options,
80 )
81 .expect("DictionaryTracker configured above to not fail on replacement");
82
83 debug_assert_eq!(encoded_dictionaries.len(), 0);
89
90 encoded_batch.into()
91 }
92 FlightMessage::AffectedRows(rows) => {
93 let metadata = FlightMetadata {
94 affected_rows: Some(AffectedRows { value: rows as _ }),
95 metrics: None,
96 }
97 .encode_to_vec();
98 FlightData {
99 flight_descriptor: None,
100 data_header: build_none_flight_msg().into(),
101 app_metadata: metadata.into(),
102 data_body: ProstBytes::default(),
103 }
104 }
105 FlightMessage::Metrics(s) => {
106 let metadata = FlightMetadata {
107 affected_rows: None,
108 metrics: Some(Metrics {
109 metrics: s.as_bytes().to_vec(),
110 }),
111 }
112 .encode_to_vec();
113 FlightData {
114 flight_descriptor: None,
115 data_header: build_none_flight_msg().into(),
116 app_metadata: metadata.into(),
117 data_body: ProstBytes::default(),
118 }
119 }
120 }
121 }
122}
123
124#[derive(Default)]
125pub struct FlightDecoder {
126 schema: Option<SchemaRef>,
127}
128
129impl FlightDecoder {
130 pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<FlightMessage> {
131 let message = root_as_message(&flight_data.data_header).map_err(|e| {
132 InvalidFlightDataSnafu {
133 reason: e.to_string(),
134 }
135 .build()
136 })?;
137 match message.header_type() {
138 MessageHeader::NONE => {
139 let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
140 .context(DecodeFlightDataSnafu)?;
141 if let Some(AffectedRows { value }) = metadata.affected_rows {
142 return Ok(FlightMessage::AffectedRows(value as _));
143 }
144 if let Some(Metrics { metrics }) = metadata.metrics {
145 return Ok(FlightMessage::Metrics(
146 String::from_utf8_lossy(&metrics).to_string(),
147 ));
148 }
149 InvalidFlightDataSnafu {
150 reason: "Expecting FlightMetadata have some meaningful content.",
151 }
152 .fail()
153 }
154 MessageHeader::Schema => {
155 let arrow_schema = ArrowSchema::try_from(flight_data).map_err(|e| {
156 InvalidFlightDataSnafu {
157 reason: e.to_string(),
158 }
159 .build()
160 })?;
161 let schema =
162 Arc::new(Schema::try_from(arrow_schema).context(ConvertArrowSchemaSnafu)?);
163
164 self.schema = Some(schema.clone());
165
166 Ok(FlightMessage::Schema(schema))
167 }
168 MessageHeader::RecordBatch => {
169 let schema = self.schema.clone().context(InvalidFlightDataSnafu {
170 reason: "Should have decoded schema first!",
171 })?;
172 let arrow_schema = schema.arrow_schema().clone();
173
174 let arrow_batch =
175 flight_data_to_arrow_batch(flight_data, arrow_schema, &HashMap::new())
176 .map_err(|e| {
177 InvalidFlightDataSnafu {
178 reason: e.to_string(),
179 }
180 .build()
181 })?;
182 let recordbatch = RecordBatch::try_from_df_record_batch(schema, arrow_batch)
183 .context(CreateRecordBatchSnafu)?;
184 Ok(FlightMessage::Recordbatch(recordbatch))
185 }
186 other => {
187 let name = other.variant_name().unwrap_or("UNKNOWN");
188 InvalidFlightDataSnafu {
189 reason: format!("Unsupported FlightData type: {name}"),
190 }
191 .fail()
192 }
193 }
194 }
195
196 pub fn schema(&self) -> Option<&SchemaRef> {
197 self.schema.as_ref()
198 }
199}
200
201pub fn flight_messages_to_recordbatches(messages: Vec<FlightMessage>) -> Result<RecordBatches> {
202 if messages.is_empty() {
203 Ok(RecordBatches::empty())
204 } else {
205 let mut recordbatches = Vec::with_capacity(messages.len() - 1);
206
207 let schema = match &messages[0] {
208 FlightMessage::Schema(schema) => schema.clone(),
209 _ => {
210 return InvalidFlightDataSnafu {
211 reason: "First Flight Message must be schema!",
212 }
213 .fail()
214 }
215 };
216
217 for message in messages.into_iter().skip(1) {
218 match message {
219 FlightMessage::Recordbatch(recordbatch) => recordbatches.push(recordbatch),
220 _ => {
221 return InvalidFlightDataSnafu {
222 reason: "Expect the following Flight Messages are all Recordbatches!",
223 }
224 .fail()
225 }
226 }
227 }
228
229 RecordBatches::try_new(schema, recordbatches).context(CreateRecordBatchSnafu)
230 }
231}
232
233fn build_none_flight_msg() -> Bytes {
234 let mut builder = FlatBufferBuilder::new();
235
236 let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
237 message.add_version(arrow::ipc::MetadataVersion::V5);
238 message.add_header_type(MessageHeader::NONE);
239 message.add_bodyLength(0);
240
241 let data = message.finish();
242 builder.finish(data, None);
243
244 builder.finished_data().into()
245}
246
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::datatypes::{DataType, Field};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::Int32Vector;

    use super::*;
    use crate::Error;

    /// Builds a record batch with a single `Int32` column holding `values`.
    fn int32_batch(schema: &SchemaRef, values: Vec<Option<i32>>) -> RecordBatch {
        RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from(values)) as _],
        )
        .unwrap()
    }

    /// Asserts that `result` is an invalid-flight-data error whose text contains `expected`.
    fn assert_invalid_flight_data<T: std::fmt::Debug>(
        result: std::result::Result<T, Error>,
        expected: &str,
    ) {
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result.unwrap_err().to_string().contains(expected));
    }

    /// Asserts that `message` is a schema message and returns the schema.
    fn expect_schema(message: FlightMessage) -> SchemaRef {
        assert!(matches!(message, FlightMessage::Schema(_)));
        let FlightMessage::Schema(schema) = message else {
            unreachable!()
        };
        schema
    }

    /// Asserts that `message` is a record batch message and returns the batch.
    fn expect_recordbatch(message: FlightMessage) -> RecordBatch {
        assert!(matches!(message, FlightMessage::Recordbatch(_)));
        let FlightMessage::Recordbatch(recordbatch) = message else {
            unreachable!()
        };
        recordbatch
    }

    #[test]
    fn test_try_decode() {
        let arrow_schema = ArrowSchema::new(vec![Field::new("n", DataType::Int32, true)]);
        let schema = Arc::new(Schema::try_from(arrow_schema.clone()).unwrap());

        let batch1 = int32_batch(&schema, vec![Some(1), None, Some(3)]);
        let batch2 = int32_batch(&schema, vec![None, Some(5)]);

        // One schema frame followed by one frame per batch.
        let flight_data = batches_to_flight_data(
            &arrow_schema,
            vec![
                batch1.clone().into_df_record_batch(),
                batch2.clone().into_df_record_batch(),
            ],
        )
        .unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let mut decoder = FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // A record batch frame cannot be decoded before the schema frame.
        assert_invalid_flight_data(
            decoder.try_decode(d2),
            "Should have decoded schema first!",
        );

        let decoded_schema = expect_schema(decoder.try_decode(d1).unwrap());
        assert_eq!(decoded_schema, schema);

        // The decoder must now hold the schema.
        let _ = decoder.schema.as_ref().unwrap();

        let actual_batch = expect_recordbatch(decoder.try_decode(d2).unwrap());
        assert_eq!(actual_batch, batch1);

        let actual_batch = expect_recordbatch(decoder.try_decode(d3).unwrap());
        assert_eq!(actual_batch, batch2);
    }

    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "m",
            ConcreteDataType::int32_datatype(),
            true,
        )]));
        let batch1 = int32_batch(&schema, vec![Some(2), None, Some(4)]);
        let batch2 = int32_batch(&schema, vec![None, Some(6)]);
        let recordbatches =
            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::Recordbatch(batch1);
        let m3 = FlightMessage::Recordbatch(batch2);

        // The stream does not start with a schema.
        assert_invalid_flight_data(
            flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]),
            "First Flight Message must be schema!",
        );

        // A second schema appears where only record batches are allowed.
        assert_invalid_flight_data(
            flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]),
            "Expect the following Flight Messages are all Recordbatches!",
        );

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }
}