common_test_util/
recordbatch.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use client::Database;
16use common_query::OutputData;
17use common_recordbatch::util;
18
/// The outcome a test expects from executing a single SQL statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpectedOutput<'a> {
    /// The statement must report exactly this many affected rows.
    AffectedRows(usize),
    /// The statement must yield a result set whose pretty-printed table is
    /// exactly this string.
    QueryResult(&'a str),
}
23
24pub async fn check_output_stream(output: OutputData, expected: &str) {
25    let recordbatches = match output {
26        OutputData::Stream(stream) => util::collect_batches(stream).await.unwrap(),
27        OutputData::RecordBatches(recordbatches) => recordbatches,
28        _ => unreachable!(),
29    };
30    let pretty_print = recordbatches.pretty_print().unwrap();
31    assert_eq!(pretty_print, expected, "actual: \n{}", pretty_print);
32}
33
34pub async fn execute_and_check_output(db: &Database, sql: &str, expected: ExpectedOutput<'_>) {
35    let output = db.sql(sql).await.unwrap();
36    let output = output.data;
37
38    match (&output, expected) {
39        (OutputData::AffectedRows(x), ExpectedOutput::AffectedRows(y)) => {
40            assert_eq!(
41                *x, y,
42                r#"
43expected: {y}
44actual: {x}
45"#
46            )
47        }
48        (OutputData::RecordBatches(_), ExpectedOutput::QueryResult(x))
49        | (OutputData::Stream(_), ExpectedOutput::QueryResult(x)) => {
50            check_output_stream(output, x).await
51        }
52        _ => panic!(),
53    }
54}