1use chrono::{DateTime as ChronoDateTime, NaiveDate, NaiveDateTime, Utc};
16use common_time::date::Date;
17use common_time::Timestamp;
18use datatypes::value::Value;
19use snafu::{ensure, ResultExt};
20use sqlx::mysql::MySqlRow;
21use sqlx::{Column, ColumnIndex, Database, MySqlPool, Row, TypeInfo, ValueRef};
22
23use crate::error::{self, Result};
24use crate::ir::insert_expr::{RowValue, RowValues};
25
26pub fn assert_eq<'a, DB>(
28 columns: &[crate::ir::Column],
29 fetched_rows: &'a [<DB as Database>::Row],
30 rows: &[RowValues],
31) -> Result<()>
32where
33 DB: Database,
34 usize: ColumnIndex<<DB as Database>::Row>,
35 bool: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
36 i8: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
37 i16: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
38 i32: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
39 i64: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
40 f32: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
41 f64: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
42 String: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
43 Vec<u8>: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
44 ChronoDateTime<Utc>: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
45 NaiveDateTime: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
46 NaiveDate: sqlx::Type<DB> + sqlx::Decode<'a, DB>,
47{
48 ensure!(
49 fetched_rows.len() == rows.len(),
50 error::AssertSnafu {
51 reason: format!(
52 "Expected values length: {}, got: {}",
53 rows.len(),
54 fetched_rows.len(),
55 )
56 }
57 );
58
59 for (idx, fetched_row) in fetched_rows.iter().enumerate() {
60 let row = &rows[idx];
61
62 ensure!(
63 fetched_row.len() == row.len(),
64 error::AssertSnafu {
65 reason: format!(
66 "Expected row length: {}, got: {}",
67 row.len(),
68 fetched_row.len(),
69 )
70 }
71 );
72
73 for (idx, value) in row.iter().enumerate() {
74 let fetched_value = if fetched_row.try_get_raw(idx).unwrap().is_null() {
75 RowValue::Value(Value::Null)
76 } else {
77 let value_type = fetched_row.column(idx).type_info().name();
78 match value_type {
79 "BOOL" | "BOOLEAN" => RowValue::Value(Value::Boolean(
80 fetched_row.try_get::<bool, usize>(idx).unwrap(),
81 )),
82 "TINYINT" => {
83 RowValue::Value(Value::Int8(fetched_row.try_get::<i8, usize>(idx).unwrap()))
84 }
85 "SMALLINT" => RowValue::Value(Value::Int16(
86 fetched_row.try_get::<i16, usize>(idx).unwrap(),
87 )),
88 "INT" => RowValue::Value(Value::Int32(
89 fetched_row.try_get::<i32, usize>(idx).unwrap(),
90 )),
91 "BIGINT" => RowValue::Value(Value::Int64(
92 fetched_row.try_get::<i64, usize>(idx).unwrap(),
93 )),
94 "FLOAT" => RowValue::Value(Value::Float32(datatypes::value::OrderedFloat(
95 fetched_row.try_get::<f32, usize>(idx).unwrap(),
96 ))),
97 "DOUBLE" => RowValue::Value(Value::Float64(datatypes::value::OrderedFloat(
98 fetched_row.try_get::<f64, usize>(idx).unwrap(),
99 ))),
100 "VARCHAR" | "CHAR" | "TEXT" => RowValue::Value(Value::String(
101 fetched_row.try_get::<String, usize>(idx).unwrap().into(),
102 )),
103 "VARBINARY" | "BINARY" | "BLOB" => RowValue::Value(Value::Binary(
104 fetched_row.try_get::<Vec<u8>, usize>(idx).unwrap().into(),
105 )),
106 "TIMESTAMP" => RowValue::Value(Value::Timestamp(
107 Timestamp::from_chrono_datetime(
108 fetched_row
109 .try_get::<ChronoDateTime<Utc>, usize>(idx)
110 .unwrap()
111 .naive_utc(),
112 )
113 .unwrap(),
114 )),
115 "DATETIME" => RowValue::Value(Value::Timestamp(
116 Timestamp::from_chrono_datetime(
117 fetched_row
118 .try_get::<ChronoDateTime<Utc>, usize>(idx)
119 .unwrap()
120 .naive_utc(),
121 )
122 .unwrap(),
123 )),
124 "DATE" => RowValue::Value(Value::Date(Date::from(
125 fetched_row.try_get::<NaiveDate, usize>(idx).unwrap(),
126 ))),
127 _ => panic!("Unsupported type: {}", value_type),
128 }
129 };
130
131 let value = match value {
132 RowValue::Value(Value::Boolean(v)) => RowValue::Value(Value::Int8(*v as i8)),
134 RowValue::Default => match columns[idx].default_value().unwrap().clone() {
135 Value::Boolean(v) => RowValue::Value(Value::Int8(v as i8)),
136 default_value => RowValue::Value(default_value),
137 },
138 _ => value.clone(),
139 };
140 ensure!(
141 value == fetched_value,
142 error::AssertSnafu {
143 reason: format!("Expected value: {:?}, got: {:?}", value, fetched_value)
144 }
145 )
146 }
147 }
148
149 Ok(())
150}
151
152#[derive(Debug, sqlx::FromRow)]
153pub struct ValueCount {
154 pub count: i64,
155}
156
157pub async fn count_values(db: &MySqlPool, sql: &str) -> Result<ValueCount> {
158 sqlx::query_as::<_, ValueCount>(sql)
159 .fetch_one(db)
160 .await
161 .context(error::ExecuteQuerySnafu { sql })
162}
163
164pub async fn fetch_values(db: &MySqlPool, sql: &str) -> Result<Vec<MySqlRow>> {
166 sqlx::query(sql)
167 .fetch_all(db)
168 .await
169 .context(error::ExecuteQuerySnafu { sql })
170}