// servers/http/result/null_result.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt::Write;
16
17use axum::http::{header, HeaderValue};
18use axum::response::{IntoResponse, Response};
19use common_error::status_code::StatusCode;
20use common_query::Output;
21use mime_guess::mime;
22use serde::{Deserialize, Serialize};
23
24use crate::http::header::{GREPTIME_DB_HEADER_EXECUTION_TIME, GREPTIME_DB_HEADER_FORMAT};
25use crate::http::result::error_result::ErrorResponse;
26use crate::http::{handler, GreptimeQueryOutput, HttpResponse, ResponseFormat};
27
28#[derive(Serialize, Deserialize, Debug)]
29enum Rows {
30    Affected(usize),
31    Queried(usize),
32}
33
34/// The null format is a simple text format that outputs the number of affected rows or queried rows
35#[derive(Serialize, Deserialize, Debug)]
36pub struct NullResponse {
37    rows: Rows,
38    execution_time_ms: u64,
39}
40
41impl NullResponse {
42    pub async fn from_output(outputs: Vec<crate::error::Result<Output>>) -> HttpResponse {
43        match handler::from_output(outputs).await {
44            Err(err) => HttpResponse::Error(err),
45            Ok((mut output, _)) => {
46                if output.len() > 1 {
47                    HttpResponse::Error(ErrorResponse::from_error_message(
48                        StatusCode::InvalidArguments,
49                        "cannot output multi-statements result in null format".to_string(),
50                    ))
51                } else {
52                    match output.pop() {
53                        Some(GreptimeQueryOutput::AffectedRows(rows)) => {
54                            HttpResponse::Null(NullResponse {
55                                rows: Rows::Affected(rows),
56                                execution_time_ms: 0,
57                            })
58                        }
59
60                        Some(GreptimeQueryOutput::Records(records)) => {
61                            HttpResponse::Null(NullResponse {
62                                rows: Rows::Queried(records.num_rows()),
63                                execution_time_ms: 0,
64                            })
65                        }
66                        _ => HttpResponse::Error(ErrorResponse::from_error_message(
67                            StatusCode::Unexpected,
68                            "unexpected output type".to_string(),
69                        )),
70                    }
71                }
72            }
73        }
74    }
75
76    /// Returns the number of rows affected or queried.
77    pub fn rows(&self) -> usize {
78        match &self.rows {
79            Rows::Affected(rows) => *rows,
80            Rows::Queried(rows) => *rows,
81        }
82    }
83
84    /// Consumes `self`, updates the execution time in milliseconds, and returns the updated instance.
85    pub(crate) fn with_execution_time(mut self, execution_time: u64) -> Self {
86        self.execution_time_ms = execution_time;
87        self
88    }
89}
90
91impl IntoResponse for NullResponse {
92    fn into_response(self) -> Response {
93        let mut body = String::new();
94        match self.rows {
95            Rows::Affected(rows) => {
96                let _ = writeln!(body, "{} rows affected.", rows);
97            }
98            Rows::Queried(rows) => {
99                let _ = writeln!(body, "{} rows in set.", rows);
100            }
101        }
102        let elapsed_secs = (self.execution_time_ms as f64) / 1000.0;
103        let _ = writeln!(body, "Elapsed: {:.3} sec.", elapsed_secs);
104
105        let mut resp = (
106            [(
107                header::CONTENT_TYPE,
108                HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()),
109            )],
110            body,
111        )
112            .into_response();
113        resp.headers_mut().insert(
114            &GREPTIME_DB_HEADER_FORMAT,
115            HeaderValue::from_static(ResponseFormat::Null.as_str()),
116        );
117        resp.headers_mut().insert(
118            &GREPTIME_DB_HEADER_EXECUTION_TIME,
119            HeaderValue::from(self.execution_time_ms),
120        );
121
122        resp
123    }
124}
#[cfg(test)]
mod tests {
    use axum::body::to_bytes;
    use axum::http;

    use super::*;

    /// Reads the whole (small) response body and decodes it as UTF-8 text.
    async fn body_text(response: Response) -> String {
        let body_bytes = to_bytes(response.into_body(), 1024).await.unwrap();
        String::from_utf8(body_bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_into_response_format() {
        let result = NullResponse {
            rows: Rows::Queried(42),
            execution_time_ms: 1234,
        };
        let response = result.into_response();

        // Check status code
        assert_eq!(response.status(), http::StatusCode::OK);

        // Check headers
        let headers = response.headers();
        assert_eq!(
            headers.get(axum::http::header::CONTENT_TYPE).unwrap(),
            mime::TEXT_PLAIN_UTF_8.as_ref()
        );
        assert_eq!(
            headers.get(&GREPTIME_DB_HEADER_FORMAT).unwrap(),
            ResponseFormat::Null.as_str()
        );
        assert_eq!(
            headers.get(&GREPTIME_DB_HEADER_EXECUTION_TIME).unwrap(),
            "1234"
        );

        // Check body
        let body = body_text(response).await;
        assert!(body.contains("42 rows in set."));
        assert!(body.contains("Elapsed: 1.234 sec."));
    }

    /// Covers the `Rows::Affected` rendering branch together with the
    /// `rows()` accessor and the `with_execution_time` builder.
    #[tokio::test]
    async fn test_affected_rows_format() {
        let result = NullResponse {
            rows: Rows::Affected(7),
            execution_time_ms: 0,
        }
        .with_execution_time(5);
        assert_eq!(result.rows(), 7);

        let response = result.into_response();
        assert_eq!(response.status(), http::StatusCode::OK);
        assert_eq!(
            response
                .headers()
                .get(&GREPTIME_DB_HEADER_EXECUTION_TIME)
                .unwrap(),
            "5"
        );

        let body = body_text(response).await;
        assert_eq!(body, "7 rows affected.\nElapsed: 0.005 sec.\n");
    }
}
164}