1pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::{RecordBatch, RecordBatches};
25use datatypes::arrow;
26use datatypes::arrow::datatypes::Schema as ArrowSchema;
27use datatypes::arrow::ipc::{root_as_message, writer, MessageHeader};
28use datatypes::schema::{Schema, SchemaRef};
29use flatbuffers::FlatBufferBuilder;
30use prost::bytes::Bytes as ProstBytes;
31use prost::Message;
32use snafu::{OptionExt, ResultExt};
33
34use crate::error::{
35 ConvertArrowSchemaSnafu, CreateRecordBatchSnafu, DecodeFlightDataSnafu, InvalidFlightDataSnafu,
36 Result,
37};
38
/// One logical message exchanged over Arrow Flight.
///
/// [`FlightEncoder::encode`] turns each variant into exactly one [`FlightData`], and
/// [`FlightDecoder::try_decode`] turns a [`FlightData`] back into a variant.
#[derive(Debug, Clone)]
pub enum FlightMessage {
    /// Schema for the record batches that follow; sent as an Arrow IPC schema message.
    Schema(SchemaRef),
    /// A batch of rows; sent as an Arrow IPC record batch message.
    Recordbatch(RecordBatch),
    /// An affected-rows count; sent with a `NONE` IPC header and carried in the protobuf
    /// `FlightMetadata` stored in `app_metadata`.
    AffectedRows(usize),
    /// Metrics text; sent with a `NONE` IPC header as UTF-8 bytes inside `FlightMetadata`.
    Metrics(String),
}
46
/// Encodes [`FlightMessage`]s into Arrow Flight [`FlightData`].
///
/// Keeps the Arrow IPC writer state that is reused across calls to [`FlightEncoder::encode`].
pub struct FlightEncoder {
    /// IPC write options used for schema and record batch messages.
    write_options: writer::IpcWriteOptions,
    /// Generates IPC-encoded record batch messages.
    data_gen: writer::IpcDataGenerator,
    /// Dictionary state passed to `data_gen` on every record batch encode.
    dictionary_tracker: writer::DictionaryTracker,
}
52
53impl Default for FlightEncoder {
54 fn default() -> Self {
55 let write_options = writer::IpcWriteOptions::default()
56 .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
57 .unwrap();
58
59 Self {
60 write_options,
61 data_gen: writer::IpcDataGenerator::default(),
62 dictionary_tracker: writer::DictionaryTracker::new(false),
63 }
64 }
65}
66
67impl FlightEncoder {
68 pub fn encode(&mut self, flight_message: FlightMessage) -> FlightData {
69 match flight_message {
70 FlightMessage::Schema(schema) => {
71 SchemaAsIpc::new(schema.arrow_schema(), &self.write_options).into()
72 }
73 FlightMessage::Recordbatch(recordbatch) => {
74 let (encoded_dictionaries, encoded_batch) = self
75 .data_gen
76 .encoded_batch(
77 recordbatch.df_record_batch(),
78 &mut self.dictionary_tracker,
79 &self.write_options,
80 )
81 .expect("DictionaryTracker configured above to not fail on replacement");
82
83 debug_assert_eq!(encoded_dictionaries.len(), 0);
89
90 encoded_batch.into()
91 }
92 FlightMessage::AffectedRows(rows) => {
93 let metadata = FlightMetadata {
94 affected_rows: Some(AffectedRows { value: rows as _ }),
95 metrics: None,
96 }
97 .encode_to_vec();
98 FlightData {
99 flight_descriptor: None,
100 data_header: build_none_flight_msg().into(),
101 app_metadata: metadata.into(),
102 data_body: ProstBytes::default(),
103 }
104 }
105 FlightMessage::Metrics(s) => {
106 let metadata = FlightMetadata {
107 affected_rows: None,
108 metrics: Some(Metrics {
109 metrics: s.as_bytes().to_vec(),
110 }),
111 }
112 .encode_to_vec();
113 FlightData {
114 flight_descriptor: None,
115 data_header: build_none_flight_msg().into(),
116 app_metadata: metadata.into(),
117 data_body: ProstBytes::default(),
118 }
119 }
120 }
121 }
122}
123
/// Decodes Arrow Flight [`FlightData`] back into [`FlightMessage`]s.
///
/// The decoder is stateful. A schema message must be decoded before any record batch
/// message, because record batches are decoded against the most recently decoded schema.
#[derive(Default)]
pub struct FlightDecoder {
    /// The most recently decoded schema; `None` until a schema message has been decoded.
    schema: Option<SchemaRef>,
}
128
129impl FlightDecoder {
130 pub fn try_decode(&mut self, flight_data: FlightData) -> Result<FlightMessage> {
131 let message = root_as_message(&flight_data.data_header).map_err(|e| {
132 InvalidFlightDataSnafu {
133 reason: e.to_string(),
134 }
135 .build()
136 })?;
137 match message.header_type() {
138 MessageHeader::NONE => {
139 let metadata = FlightMetadata::decode(flight_data.app_metadata)
140 .context(DecodeFlightDataSnafu)?;
141 if let Some(AffectedRows { value }) = metadata.affected_rows {
142 return Ok(FlightMessage::AffectedRows(value as _));
143 }
144 if let Some(Metrics { metrics }) = metadata.metrics {
145 return Ok(FlightMessage::Metrics(
146 String::from_utf8_lossy(&metrics).to_string(),
147 ));
148 }
149 InvalidFlightDataSnafu {
150 reason: "Expecting FlightMetadata have some meaningful content.",
151 }
152 .fail()
153 }
154 MessageHeader::Schema => {
155 let arrow_schema = ArrowSchema::try_from(&flight_data).map_err(|e| {
156 InvalidFlightDataSnafu {
157 reason: e.to_string(),
158 }
159 .build()
160 })?;
161 let schema =
162 Arc::new(Schema::try_from(arrow_schema).context(ConvertArrowSchemaSnafu)?);
163
164 self.schema = Some(schema.clone());
165
166 Ok(FlightMessage::Schema(schema))
167 }
168 MessageHeader::RecordBatch => {
169 let schema = self.schema.clone().context(InvalidFlightDataSnafu {
170 reason: "Should have decoded schema first!",
171 })?;
172 let arrow_schema = schema.arrow_schema().clone();
173
174 let arrow_batch =
175 flight_data_to_arrow_batch(&flight_data, arrow_schema, &HashMap::new())
176 .map_err(|e| {
177 InvalidFlightDataSnafu {
178 reason: e.to_string(),
179 }
180 .build()
181 })?;
182 let recordbatch = RecordBatch::try_from_df_record_batch(schema, arrow_batch)
183 .context(CreateRecordBatchSnafu)?;
184 Ok(FlightMessage::Recordbatch(recordbatch))
185 }
186 other => {
187 let name = other.variant_name().unwrap_or("UNKNOWN");
188 InvalidFlightDataSnafu {
189 reason: format!("Unsupported FlightData type: {name}"),
190 }
191 .fail()
192 }
193 }
194 }
195}
196
197pub fn flight_messages_to_recordbatches(messages: Vec<FlightMessage>) -> Result<RecordBatches> {
198 if messages.is_empty() {
199 Ok(RecordBatches::empty())
200 } else {
201 let mut recordbatches = Vec::with_capacity(messages.len() - 1);
202
203 let schema = match &messages[0] {
204 FlightMessage::Schema(schema) => schema.clone(),
205 _ => {
206 return InvalidFlightDataSnafu {
207 reason: "First Flight Message must be schema!",
208 }
209 .fail()
210 }
211 };
212
213 for message in messages.into_iter().skip(1) {
214 match message {
215 FlightMessage::Recordbatch(recordbatch) => recordbatches.push(recordbatch),
216 _ => {
217 return InvalidFlightDataSnafu {
218 reason: "Expect the following Flight Messages are all Recordbatches!",
219 }
220 .fail()
221 }
222 }
223 }
224
225 RecordBatches::try_new(schema, recordbatches).context(CreateRecordBatchSnafu)
226 }
227}
228
229fn build_none_flight_msg() -> Bytes {
230 let mut builder = FlatBufferBuilder::new();
231
232 let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
233 message.add_version(arrow::ipc::MetadataVersion::V5);
234 message.add_header_type(MessageHeader::NONE);
235 message.add_bodyLength(0);
236
237 let data = message.finish();
238 builder.finish(data, None);
239
240 builder.finished_data().into()
241}
242
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::datatypes::{DataType, Field};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::Int32Vector;

    use super::*;
    use crate::Error;

    #[test]
    fn test_try_decode() {
        let arrow_schema = ArrowSchema::new(vec![Field::new("n", DataType::Int32, true)]);
        let schema = Arc::new(Schema::try_from(arrow_schema.clone()).unwrap());

        let batch1 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from(vec![Some(1), None, Some(3)])) as _],
        )
        .unwrap();
        let batch2 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from(vec![None, Some(5)])) as _],
        )
        .unwrap();

        let flight_data = batches_to_flight_data(
            &arrow_schema,
            vec![
                batch1.clone().into_df_record_batch(),
                batch2.clone().into_df_record_batch(),
            ],
        )
        .unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let decoder = &mut FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // A record batch cannot be decoded before its schema.
        let result = decoder.try_decode(d2.clone());
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Should have decoded schema first!"));

        let FlightMessage::Schema(decoded_schema) = decoder.try_decode(d1.clone()).unwrap() else {
            panic!("expect a Schema message")
        };
        assert_eq!(decoded_schema, schema);

        // The decoder must remember the schema for the record batches that follow.
        assert_eq!(decoder.schema.as_ref(), Some(&schema));

        let FlightMessage::Recordbatch(actual_batch) = decoder.try_decode(d2.clone()).unwrap()
        else {
            panic!("expect a Recordbatch message")
        };
        assert_eq!(actual_batch, batch1);

        let FlightMessage::Recordbatch(actual_batch) = decoder.try_decode(d3.clone()).unwrap()
        else {
            panic!("expect a Recordbatch message")
        };
        assert_eq!(actual_batch, batch2);
    }

    #[test]
    fn test_encode_decode_round_trip() {
        let arrow_schema = ArrowSchema::new(vec![Field::new("n", DataType::Int32, true)]);
        let schema = Arc::new(Schema::try_from(arrow_schema).unwrap());
        let batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from(vec![Some(1), None, Some(3)])) as _],
        )
        .unwrap();

        let encoder = &mut FlightEncoder::default();
        let decoder = &mut FlightDecoder::default();

        match decoder
            .try_decode(encoder.encode(FlightMessage::Schema(schema.clone())))
            .unwrap()
        {
            FlightMessage::Schema(decoded) => assert_eq!(decoded, schema),
            other => panic!("expect a Schema message, got {other:?}"),
        }

        match decoder
            .try_decode(encoder.encode(FlightMessage::Recordbatch(batch.clone())))
            .unwrap()
        {
            FlightMessage::Recordbatch(decoded) => assert_eq!(decoded, batch),
            other => panic!("expect a Recordbatch message, got {other:?}"),
        }

        let decoded = decoder
            .try_decode(encoder.encode(FlightMessage::AffectedRows(42)))
            .unwrap();
        assert!(matches!(decoded, FlightMessage::AffectedRows(42)));

        let metrics = r#"{"elapsed_ms":7}"#.to_string();
        match decoder
            .try_decode(encoder.encode(FlightMessage::Metrics(metrics.clone())))
            .unwrap()
        {
            FlightMessage::Metrics(decoded) => assert_eq!(decoded, metrics),
            other => panic!("expect a Metrics message, got {other:?}"),
        }
    }

    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "m",
            ConcreteDataType::int32_datatype(),
            true,
        )]));
        let batch1 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from(vec![Some(2), None, Some(4)])) as _],
        )
        .unwrap();
        let batch2 = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from(vec![None, Some(6)])) as _],
        )
        .unwrap();
        let recordbatches =
            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::Recordbatch(batch1);
        let m3 = FlightMessage::Recordbatch(batch2);

        let result = flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("First Flight Message must be schema!"));

        let result = flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Expect the following Flight Messages are all Recordbatches!"));

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }
}