sqlness_runner/
formatter.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::fmt::Display;
17use std::sync::Arc;
18
19use client::{Output, OutputData, RecordBatches};
20use common_error::ext::ErrorExt;
21use datatypes::prelude::ConcreteDataType;
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::schema::{ColumnSchema, Schema};
24use datatypes::vectors::{StringVectorBuilder, VectorRef};
25use mysql::Row as MySqlRow;
26use tokio_postgres::SimpleQueryMessage as PgRow;
27
28use crate::client::MysqlSqlResult;
29
30/// A formatter for errors.
31pub struct ErrorFormatter<E: ErrorExt>(E);
32
33impl<E: ErrorExt> From<E> for ErrorFormatter<E> {
34    fn from(error: E) -> Self {
35        ErrorFormatter(error)
36    }
37}
38
39impl<E: ErrorExt> Display for ErrorFormatter<E> {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        let status_code = self.0.status_code();
42        let root_cause = self.0.output_msg();
43        write!(
44            f,
45            "Error: {}({status_code}), {root_cause}",
46            status_code as u32
47        )
48    }
49}
50
51/// A formatter for [`Output`].
52pub struct OutputFormatter(Output);
53
54impl From<Output> for OutputFormatter {
55    fn from(output: Output) -> Self {
56        OutputFormatter(output)
57    }
58}
59
60impl Display for OutputFormatter {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        match &self.0.data {
63            OutputData::AffectedRows(rows) => {
64                write!(f, "Affected Rows: {rows}")
65            }
66            OutputData::RecordBatches(recordbatches) => {
67                let pretty = recordbatches.pretty_print().map_err(|e| e.to_string());
68                match pretty {
69                    Ok(s) => write!(f, "{s}"),
70                    Err(e) => {
71                        write!(f, "Failed to pretty format {recordbatches:?}, error: {e}")
72                    }
73                }
74            }
75            OutputData::Stream(_) => unreachable!(),
76        }
77    }
78}
79
80pub struct PostgresqlFormatter(Vec<PgRow>);
81
82impl From<Vec<PgRow>> for PostgresqlFormatter {
83    fn from(rows: Vec<PgRow>) -> Self {
84        PostgresqlFormatter(rows)
85    }
86}
87
88impl Display for PostgresqlFormatter {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        if self.0.is_empty() {
91            return f.write_fmt(format_args!("(Empty response)"));
92        }
93
94        if let PgRow::CommandComplete(affected_rows) = &self.0[0] {
95            return write!(
96                f,
97                "{}",
98                OutputFormatter(Output::new_with_affected_rows(*affected_rows as usize))
99            );
100        };
101
102        let Some(recordbatches) = build_recordbatches_from_postgres_rows(&self.0) else {
103            return Ok(());
104        };
105        write!(
106            f,
107            "{}",
108            OutputFormatter(Output::new_with_record_batches(recordbatches))
109        )
110    }
111}
112
113fn build_recordbatches_from_postgres_rows(rows: &[PgRow]) -> Option<RecordBatches> {
114    // create schema
115    let schema = match &rows[0] {
116        PgRow::RowDescription(desc) => Arc::new(Schema::new(
117            desc.iter()
118                .map(|column| {
119                    ColumnSchema::new(column.name(), ConcreteDataType::string_datatype(), true)
120                })
121                .collect(),
122        )),
123        _ => unreachable!(),
124    };
125    if schema.num_columns() == 0 {
126        return None;
127    }
128
129    // convert to string vectors
130    let mut columns: Vec<StringVectorBuilder> = (0..schema.num_columns())
131        .map(|_| StringVectorBuilder::with_capacity(schema.num_columns()))
132        .collect();
133    for row in rows.iter().skip(1) {
134        if let PgRow::Row(row) = row {
135            for (i, column) in columns.iter_mut().enumerate().take(schema.num_columns()) {
136                column.push(row.get(i));
137            }
138        }
139    }
140    let columns: Vec<VectorRef> = columns
141        .into_iter()
142        .map(|mut col| Arc::new(col.finish()) as VectorRef)
143        .collect();
144
145    // construct recordbatch
146    let recordbatches = RecordBatches::try_from_columns(schema, columns)
147        .expect("Failed to construct recordbatches from columns. Please check the schema.");
148    Some(recordbatches)
149}
150
151/// A formatter for [`MysqlSqlResult`].
152pub struct MysqlFormatter(MysqlSqlResult);
153
154impl From<MysqlSqlResult> for MysqlFormatter {
155    fn from(result: MysqlSqlResult) -> Self {
156        MysqlFormatter(result)
157    }
158}
159
160impl Display for MysqlFormatter {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        match &self.0 {
163            MysqlSqlResult::AffectedRows(rows) => {
164                write!(f, "affected_rows: {rows}")
165            }
166            MysqlSqlResult::Rows(rows) => {
167                if rows.is_empty() {
168                    return f.write_fmt(format_args!("(Empty response)"));
169                }
170
171                let recordbatches = build_recordbatches_from_mysql_rows(rows);
172                write!(
173                    f,
174                    "{}",
175                    OutputFormatter(Output::new_with_record_batches(recordbatches))
176                )
177            }
178        }
179    }
180}
181
182pub fn build_recordbatches_from_mysql_rows(rows: &[MySqlRow]) -> RecordBatches {
183    // create schema
184    let head_column = &rows[0];
185    let head_binding = head_column.columns();
186    let names = head_binding
187        .iter()
188        .map(|column| column.name_str())
189        .collect::<Vec<Cow<str>>>();
190    let schema = Arc::new(Schema::new(
191        names
192            .iter()
193            .map(|name| {
194                ColumnSchema::new(name.to_string(), ConcreteDataType::string_datatype(), false)
195            })
196            .collect(),
197    ));
198
199    // convert to string vectors
200    let mut columns: Vec<StringVectorBuilder> = (0..schema.num_columns())
201        .map(|_| StringVectorBuilder::with_capacity(schema.num_columns()))
202        .collect();
203    for row in rows.iter() {
204        for (i, name) in names.iter().enumerate() {
205            columns[i].push(row.get::<String, &str>(name).as_deref());
206        }
207    }
208    let columns: Vec<VectorRef> = columns
209        .into_iter()
210        .map(|mut col| Arc::new(col.finish()) as VectorRef)
211        .collect();
212
213    // construct recordbatch
214    RecordBatches::try_from_columns(schema, columns)
215        .expect("Failed to construct recordbatches from columns. Please check the schema.")
216}