tests_fuzz/validator/
row.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use chrono::{DateTime as ChronoDateTime, NaiveDate, NaiveDateTime, Utc};
16use common_time::date::Date;
17use common_time::Timestamp;
18use datatypes::value::Value;
19use snafu::{ensure, ResultExt};
20use sqlx::mysql::MySqlRow;
21use sqlx::{Column, ColumnIndex, Database, MySqlPool, Row, TypeInfo, ValueRef};
22
23use crate::error::{self, Result};
24use crate::ir::insert_expr::{RowValue, RowValues};
25
26/// Asserts fetched_rows are equal to rows
27pub fn assert_eq<'a, DB>(
28    columns: &[crate::ir::Column],
29    fetched_rows: &'a [<DB as Database>::Row],
30    rows: &[RowValues],
31) -> Result<()>
32where
33    DB: Database,
34    usize: ColumnIndex<<DB as Database>::Row>,
35    bool: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
36    i8: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
37    i16: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
38    i32: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
39    i64: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
40    f32: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
41    f64: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
42    String: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
43    Vec<u8>: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
44    ChronoDateTime<Utc>: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
45    NaiveDateTime: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
46    NaiveDate: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
47{
48    ensure!(
49        fetched_rows.len() == rows.len(),
50        error::AssertSnafu {
51            reason: format!(
52                "Expected values length: {}, got: {}",
53                rows.len(),
54                fetched_rows.len(),
55            )
56        }
57    );
58
59    for (idx, fetched_row) in fetched_rows.iter().enumerate() {
60        let row = &rows[idx];
61
62        ensure!(
63            fetched_row.len() == row.len(),
64            error::AssertSnafu {
65                reason: format!(
66                    "Expected row length: {}, got: {}",
67                    row.len(),
68                    fetched_row.len(),
69                )
70            }
71        );
72
73        for (idx, value) in row.iter().enumerate() {
74            let fetched_value = if fetched_row.try_get_raw(idx).unwrap().is_null() {
75                RowValue::Value(Value::Null)
76            } else {
77                let value_type = fetched_row.column(idx).type_info().name();
78                match value_type {
79                    "BOOL" | "BOOLEAN" => RowValue::Value(Value::Boolean(
80                        fetched_row.try_get::<bool, usize>(idx).unwrap(),
81                    )),
82                    "TINYINT" => {
83                        RowValue::Value(Value::Int8(fetched_row.try_get::<i8, usize>(idx).unwrap()))
84                    }
85                    "SMALLINT" => RowValue::Value(Value::Int16(
86                        fetched_row.try_get::<i16, usize>(idx).unwrap(),
87                    )),
88                    "INT" => RowValue::Value(Value::Int32(
89                        fetched_row.try_get::<i32, usize>(idx).unwrap(),
90                    )),
91                    "BIGINT" => RowValue::Value(Value::Int64(
92                        fetched_row.try_get::<i64, usize>(idx).unwrap(),
93                    )),
94                    "FLOAT" => RowValue::Value(Value::Float32(datatypes::value::OrderedFloat(
95                        fetched_row.try_get::<f32, usize>(idx).unwrap(),
96                    ))),
97                    "DOUBLE" => RowValue::Value(Value::Float64(datatypes::value::OrderedFloat(
98                        fetched_row.try_get::<f64, usize>(idx).unwrap(),
99                    ))),
100                    "VARCHAR" | "CHAR" | "TEXT" => RowValue::Value(Value::String(
101                        fetched_row.try_get::<String, usize>(idx).unwrap().into(),
102                    )),
103                    "VARBINARY" | "BINARY" | "BLOB" => RowValue::Value(Value::Binary(
104                        fetched_row.try_get::<Vec<u8>, usize>(idx).unwrap().into(),
105                    )),
106                    "TIMESTAMP" => RowValue::Value(Value::Timestamp(
107                        Timestamp::from_chrono_datetime(
108                            fetched_row
109                                .try_get::<ChronoDateTime<Utc>, usize>(idx)
110                                .unwrap()
111                                .naive_utc(),
112                        )
113                        .unwrap(),
114                    )),
115                    "DATETIME" => RowValue::Value(Value::Timestamp(
116                        Timestamp::from_chrono_datetime(
117                            fetched_row
118                                .try_get::<ChronoDateTime<Utc>, usize>(idx)
119                                .unwrap()
120                                .naive_utc(),
121                        )
122                        .unwrap(),
123                    )),
124                    "DATE" => RowValue::Value(Value::Date(Date::from(
125                        fetched_row.try_get::<NaiveDate, usize>(idx).unwrap(),
126                    ))),
127                    _ => panic!("Unsupported type: {}", value_type),
128                }
129            };
130
131            let value = match value {
132                // In MySQL, boolean is stored as TINYINT(1)
133                RowValue::Value(Value::Boolean(v)) => RowValue::Value(Value::Int8(*v as i8)),
134                RowValue::Default => match columns[idx].default_value().unwrap().clone() {
135                    Value::Boolean(v) => RowValue::Value(Value::Int8(v as i8)),
136                    default_value => RowValue::Value(default_value),
137                },
138                _ => value.clone(),
139            };
140            ensure!(
141                value == fetched_value,
142                error::AssertSnafu {
143                    reason: format!("Expected value: {:?}, got: {:?}", value, fetched_value)
144                }
145            )
146        }
147    }
148
149    Ok(())
150}
151
152#[derive(Debug, sqlx::FromRow)]
153pub struct ValueCount {
154    pub count: i64,
155}
156
157pub async fn count_values(db: &MySqlPool, sql: &str) -> Result<ValueCount> {
158    sqlx::query_as::<_, ValueCount>(sql)
159        .fetch_one(db)
160        .await
161        .context(error::ExecuteQuerySnafu { sql })
162}
163
164/// Returns all [RowEntry] of the `table_name`.
165pub async fn fetch_values(db: &MySqlPool, sql: &str) -> Result<Vec<MySqlRow>> {
166    sqlx::query(sql)
167        .fetch_all(db)
168        .await
169        .context(error::ExecuteQuerySnafu { sql })
170}