// servers/http/result/arrow_result.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::sync::Arc;
17
18use arrow::datatypes::Schema;
19use arrow_ipc::writer::{FileWriter, IpcWriteOptions};
20use arrow_ipc::CompressionType;
21use axum::http::{header, HeaderValue};
22use axum::response::{IntoResponse, Response};
23use common_error::status_code::StatusCode;
24use common_query::{Output, OutputData};
25use common_recordbatch::RecordBatchStream;
26use futures::StreamExt;
27use serde::{Deserialize, Serialize};
28use snafu::ResultExt;
29
30use crate::error::{self, Error};
31use crate::http::header::{GREPTIME_DB_HEADER_EXECUTION_TIME, GREPTIME_DB_HEADER_FORMAT};
32use crate::http::result::error_result::ErrorResponse;
33use crate::http::{HttpResponse, ResponseFormat};
34
/// HTTP query result payload encoded in the Arrow IPC file format.
///
/// Returned when the client asks for the `arrow` response format; the body is
/// the raw IPC bytes and the execution time is surfaced via a response header.
#[derive(Serialize, Deserialize, Debug)]
pub struct ArrowResponse {
    /// Arrow IPC file bytes (possibly compressed); empty when the output has
    /// no record batches (e.g. affected-rows results).
    pub(crate) data: Vec<u8>,
    /// Query execution time in milliseconds, filled in by `with_execution_time`.
    pub(crate) execution_time_ms: u64,
}
40
41async fn write_arrow_bytes(
42    mut recordbatches: Pin<Box<dyn RecordBatchStream + Send>>,
43    schema: &Arc<Schema>,
44    compression: Option<CompressionType>,
45) -> Result<Vec<u8>, Error> {
46    let mut bytes = Vec::new();
47    {
48        let options = IpcWriteOptions::default()
49            .try_with_compression(compression)
50            .context(error::ArrowSnafu)?;
51        let mut writer = FileWriter::try_new_with_options(&mut bytes, schema, options)
52            .context(error::ArrowSnafu)?;
53
54        while let Some(rb) = recordbatches.next().await {
55            let rb = rb.context(error::CollectRecordbatchSnafu)?;
56            writer
57                .write(&rb.into_df_record_batch())
58                .context(error::ArrowSnafu)?;
59        }
60
61        writer.finish().context(error::ArrowSnafu)?;
62    }
63
64    Ok(bytes)
65}
66
67fn compression_type(compression: Option<String>) -> Option<CompressionType> {
68    match compression
69        .map(|compression| compression.to_lowercase())
70        .as_deref()
71    {
72        Some("zstd") => Some(CompressionType::ZSTD),
73        Some("lz4") => Some(CompressionType::LZ4_FRAME),
74        _ => None,
75    }
76}
77
78impl ArrowResponse {
79    pub async fn from_output(
80        mut outputs: Vec<error::Result<Output>>,
81        compression: Option<String>,
82    ) -> HttpResponse {
83        if outputs.len() > 1 {
84            return HttpResponse::Error(ErrorResponse::from_error_message(
85                StatusCode::InvalidArguments,
86                "cannot output multi-statements result in arrow format".to_string(),
87            ));
88        }
89
90        let compression = compression_type(compression);
91
92        match outputs.pop() {
93            None => HttpResponse::Arrow(ArrowResponse {
94                data: vec![],
95                execution_time_ms: 0,
96            }),
97            Some(Ok(output)) => match output.data {
98                OutputData::AffectedRows(_) => HttpResponse::Arrow(ArrowResponse {
99                    data: vec![],
100                    execution_time_ms: 0,
101                }),
102                OutputData::RecordBatches(batches) => {
103                    let schema = batches.schema();
104                    match write_arrow_bytes(batches.as_stream(), schema.arrow_schema(), compression)
105                        .await
106                    {
107                        Ok(payload) => HttpResponse::Arrow(ArrowResponse {
108                            data: payload,
109                            execution_time_ms: 0,
110                        }),
111                        Err(e) => HttpResponse::Error(ErrorResponse::from_error(e)),
112                    }
113                }
114                OutputData::Stream(batches) => {
115                    let schema = batches.schema();
116                    match write_arrow_bytes(batches, schema.arrow_schema(), compression).await {
117                        Ok(payload) => HttpResponse::Arrow(ArrowResponse {
118                            data: payload,
119                            execution_time_ms: 0,
120                        }),
121                        Err(e) => HttpResponse::Error(ErrorResponse::from_error(e)),
122                    }
123                }
124            },
125            Some(Err(e)) => HttpResponse::Error(ErrorResponse::from_error(e)),
126        }
127    }
128
129    pub fn with_execution_time(mut self, execution_time: u64) -> Self {
130        self.execution_time_ms = execution_time;
131        self
132    }
133
134    pub fn execution_time_ms(&self) -> u64 {
135        self.execution_time_ms
136    }
137}
138
139impl IntoResponse for ArrowResponse {
140    fn into_response(self) -> Response {
141        let execution_time = self.execution_time_ms;
142        (
143            [
144                (
145                    &header::CONTENT_TYPE,
146                    HeaderValue::from_static("application/arrow"),
147                ),
148                (
149                    &GREPTIME_DB_HEADER_FORMAT,
150                    HeaderValue::from_static(ResponseFormat::Arrow.as_str()),
151                ),
152                (
153                    &GREPTIME_DB_HEADER_EXECUTION_TIME,
154                    HeaderValue::from(execution_time),
155                ),
156            ],
157            self.data,
158        )
159            .into_response()
160    }
161}
162
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use arrow_ipc::reader::FileReader;
    use arrow_schema::DataType;
    use common_recordbatch::{RecordBatch, RecordBatches};
    use datatypes::prelude::*;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{StringVector, UInt32Vector};

    use super::*;

    /// Round-trips a record batch through `ArrowResponse::from_output` for
    /// each supported compression and verifies the decoded schema and rows.
    #[tokio::test]
    async fn test_arrow_output() {
        let schema = Arc::new(Schema::new(vec![
            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ]));
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];

        let compressions = [None, Some("zstd".to_string()), Some("lz4".to_string())];
        for compression in compressions {
            let batch = RecordBatch::new(schema.clone(), columns.clone()).unwrap();
            let batches = RecordBatches::try_new(schema.clone(), vec![batch]).unwrap();
            let outputs = vec![Ok(Output::new_with_record_batches(batches))];

            let resp = match ArrowResponse::from_output(outputs, compression).await {
                HttpResponse::Arrow(resp) => resp,
                HttpResponse::Error(e) => panic!("unexpected {:?}", e),
                _ => unreachable!(),
            };

            let mut reader =
                FileReader::try_new(Cursor::new(resp.data), None).expect("Arrow reader error");

            let decoded_schema = reader.schema();
            assert_eq!(decoded_schema.fields[0].name(), "numbers");
            assert_eq!(decoded_schema.fields[0].data_type(), &DataType::UInt32);
            assert_eq!(decoded_schema.fields[1].name(), "strings");
            assert_eq!(decoded_schema.fields[1].data_type(), &DataType::Utf8);

            let decoded = reader.next().unwrap().expect("read record batch failed");
            assert_eq!(decoded.num_columns(), 2);
            assert_eq!(decoded.num_rows(), 4);
        }
    }
}