1use std::pin::Pin;
16use std::sync::Arc;
17
18use arrow::datatypes::Schema;
19use arrow_ipc::writer::{FileWriter, IpcWriteOptions};
20use arrow_ipc::CompressionType;
21use axum::http::{header, HeaderValue};
22use axum::response::{IntoResponse, Response};
23use common_error::status_code::StatusCode;
24use common_query::{Output, OutputData};
25use common_recordbatch::RecordBatchStream;
26use futures::StreamExt;
27use serde::{Deserialize, Serialize};
28use snafu::ResultExt;
29
30use crate::error::{self, Error};
31use crate::http::header::{GREPTIME_DB_HEADER_EXECUTION_TIME, GREPTIME_DB_HEADER_FORMAT};
32use crate::http::result::error_result::ErrorResponse;
33use crate::http::{HttpResponse, ResponseFormat};
34
/// HTTP response payload carrying query results encoded in the Arrow IPC
/// file format.
#[derive(Serialize, Deserialize, Debug)]
pub struct ArrowResponse {
    // Arrow IPC (file-format) encoded bytes; empty when the statement
    // produced no record batches (e.g. affected-rows-only output).
    pub(crate) data: Vec<u8>,
    // Wall-clock execution time in milliseconds, set via `with_execution_time`.
    pub(crate) execution_time_ms: u64,
}
40
41async fn write_arrow_bytes(
42 mut recordbatches: Pin<Box<dyn RecordBatchStream + Send>>,
43 schema: &Arc<Schema>,
44 compression: Option<CompressionType>,
45) -> Result<Vec<u8>, Error> {
46 let mut bytes = Vec::new();
47 {
48 let options = IpcWriteOptions::default()
49 .try_with_compression(compression)
50 .context(error::ArrowSnafu)?;
51 let mut writer = FileWriter::try_new_with_options(&mut bytes, schema, options)
52 .context(error::ArrowSnafu)?;
53
54 while let Some(rb) = recordbatches.next().await {
55 let rb = rb.context(error::CollectRecordbatchSnafu)?;
56 writer
57 .write(&rb.into_df_record_batch())
58 .context(error::ArrowSnafu)?;
59 }
60
61 writer.finish().context(error::ArrowSnafu)?;
62 }
63
64 Ok(bytes)
65}
66
67fn compression_type(compression: Option<String>) -> Option<CompressionType> {
68 match compression
69 .map(|compression| compression.to_lowercase())
70 .as_deref()
71 {
72 Some("zstd") => Some(CompressionType::ZSTD),
73 Some("lz4") => Some(CompressionType::LZ4_FRAME),
74 _ => None,
75 }
76}
77
78impl ArrowResponse {
79 pub async fn from_output(
80 mut outputs: Vec<error::Result<Output>>,
81 compression: Option<String>,
82 ) -> HttpResponse {
83 if outputs.len() > 1 {
84 return HttpResponse::Error(ErrorResponse::from_error_message(
85 StatusCode::InvalidArguments,
86 "cannot output multi-statements result in arrow format".to_string(),
87 ));
88 }
89
90 let compression = compression_type(compression);
91
92 match outputs.pop() {
93 None => HttpResponse::Arrow(ArrowResponse {
94 data: vec![],
95 execution_time_ms: 0,
96 }),
97 Some(Ok(output)) => match output.data {
98 OutputData::AffectedRows(_) => HttpResponse::Arrow(ArrowResponse {
99 data: vec![],
100 execution_time_ms: 0,
101 }),
102 OutputData::RecordBatches(batches) => {
103 let schema = batches.schema();
104 match write_arrow_bytes(batches.as_stream(), schema.arrow_schema(), compression)
105 .await
106 {
107 Ok(payload) => HttpResponse::Arrow(ArrowResponse {
108 data: payload,
109 execution_time_ms: 0,
110 }),
111 Err(e) => HttpResponse::Error(ErrorResponse::from_error(e)),
112 }
113 }
114 OutputData::Stream(batches) => {
115 let schema = batches.schema();
116 match write_arrow_bytes(batches, schema.arrow_schema(), compression).await {
117 Ok(payload) => HttpResponse::Arrow(ArrowResponse {
118 data: payload,
119 execution_time_ms: 0,
120 }),
121 Err(e) => HttpResponse::Error(ErrorResponse::from_error(e)),
122 }
123 }
124 },
125 Some(Err(e)) => HttpResponse::Error(ErrorResponse::from_error(e)),
126 }
127 }
128
129 pub fn with_execution_time(mut self, execution_time: u64) -> Self {
130 self.execution_time_ms = execution_time;
131 self
132 }
133
134 pub fn execution_time_ms(&self) -> u64 {
135 self.execution_time_ms
136 }
137}
138
139impl IntoResponse for ArrowResponse {
140 fn into_response(self) -> Response {
141 let execution_time = self.execution_time_ms;
142 (
143 [
144 (
145 &header::CONTENT_TYPE,
146 HeaderValue::from_static("application/arrow"),
147 ),
148 (
149 &GREPTIME_DB_HEADER_FORMAT,
150 HeaderValue::from_static(ResponseFormat::Arrow.as_str()),
151 ),
152 (
153 &GREPTIME_DB_HEADER_EXECUTION_TIME,
154 HeaderValue::from(execution_time),
155 ),
156 ],
157 self.data,
158 )
159 .into_response()
160 }
161}
162
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use arrow_ipc::reader::FileReader;
    use arrow_schema::DataType;
    use common_recordbatch::{RecordBatch, RecordBatches};
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{StringVector, UInt32Vector};

    use super::*;

    /// Round-trips a two-column record batch through `ArrowResponse` and
    /// reads the payload back with the Arrow IPC `FileReader`, once per
    /// supported compression setting (none, zstd, lz4).
    #[tokio::test]
    async fn test_arrow_output() {
        let column_schemas = vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];

        // Each iteration exercises one compression setting end-to-end.
        for compression in [None, Some("zstd".to_string()), Some("lz4".to_string())].into_iter() {
            let recordbatch = RecordBatch::new(schema.clone(), columns.clone()).unwrap();
            let recordbatches =
                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];

            let http_resp = ArrowResponse::from_output(outputs, compression).await;
            match http_resp {
                HttpResponse::Arrow(resp) => {
                    let output = resp.data;
                    // Decode the IPC payload and verify schema round-trip.
                    let mut reader =
                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
                    let schema = reader.schema();
                    assert_eq!(schema.fields[0].name(), "numbers");
                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
                    assert_eq!(schema.fields[1].name(), "strings");
                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);

                    // The single written batch must come back intact.
                    let rb = reader.next().unwrap().expect("read record batch failed");
                    assert_eq!(rb.num_columns(), 2);
                    assert_eq!(rb.num_rows(), 4);
                }
                HttpResponse::Error(e) => {
                    panic!("unexpected {:?}", e);
                }
                _ => unreachable!(),
            }
        }
    }
}