Skip to main content

servers/mysql/
writer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::io;
16use std::time::Duration;
17
18use arrow::array::{Array, AsArray};
19use arrow::datatypes::{
20    Date32Type, Decimal128Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type,
21    Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, UInt8Type,
22    UInt16Type, UInt32Type, UInt64Type,
23};
24use arrow_schema::{DataType, IntervalUnit};
25use common_decimal::Decimal128;
26use common_error::ext::ErrorExt;
27use common_error::status_code::StatusCode;
28use common_query::{Output, OutputData};
29use common_recordbatch::{RecordBatch, SendableRecordBatchStream};
30use common_telemetry::{debug, error};
31use common_time::{Date, IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth};
32use datafusion_common::ScalarValue;
33use datatypes::prelude::ConcreteDataType;
34use datatypes::schema::SchemaRef;
35use datatypes::types::jsonb_to_string;
36use futures::StreamExt;
37use opensrv_mysql::{
38    Column, ColumnFlags, ColumnType, ErrorKind, OkResponse, QueryResultWriter, RowWriter,
39};
40use session::SessionRef;
41use session::context::QueryContextRef;
42use snafu::prelude::*;
43use tokio::io::AsyncWrite;
44
45use crate::error::{self, ConvertSqlValueSnafu, DataFusionSnafu, NotSupportedSnafu, Result};
46use crate::metrics::*;
47
48/// Try to write multiple output to the writer if possible.
49pub async fn write_output<W: AsyncWrite + Send + Sync + Unpin>(
50    mut writer: QueryResultWriter<'_, W>,
51    query_context: QueryContextRef,
52    session: SessionRef,
53    outputs: Vec<Result<Output>>,
54) -> Result<()> {
55    if let Some(warning) = query_context.warning() {
56        session.add_warning(warning);
57    }
58
59    enum Response {
60        ResultSet {
61            columns: Vec<Column>,
62            stream: SendableRecordBatchStream,
63        },
64        AffectedRows(usize),
65    }
66
67    let mut responses = Vec::with_capacity(outputs.len());
68    for output in outputs {
69        match output {
70            Ok(x) => {
71                let output = match x.data {
72                    OutputData::Stream(stream) => either::Left(stream),
73                    OutputData::RecordBatches(record_batches) => {
74                        either::Left(record_batches.as_stream())
75                    }
76                    OutputData::AffectedRows(rows) => either::Right(rows),
77                };
78                responses.push(match output {
79                    either::Left(stream) => {
80                        let schema = stream.schema();
81                        let columns = match create_mysql_column_def(&schema) {
82                            Ok(columns) => columns,
83                            Err(e) => {
84                                MysqlResultWriter::write_query_error(
85                                    e,
86                                    writer,
87                                    query_context.clone(),
88                                )
89                                .await?;
90                                return Ok(());
91                            }
92                        };
93                        Response::ResultSet { columns, stream }
94                    }
95                    either::Right(rows) => Response::AffectedRows(rows),
96                });
97            }
98            Err(e) => {
99                MysqlResultWriter::write_query_error(e, writer, query_context.clone()).await?;
100                return Ok(());
101            }
102        }
103    }
104
105    for response in &mut responses {
106        writer = match response {
107            Response::ResultSet { columns, stream } => {
108                let mut row_writer = writer.start(columns).await?;
109                while let Some(record_batch) = stream.next().await {
110                    match record_batch {
111                        Ok(record_batch) => {
112                            if let Err(e) = MysqlResultWriter::write_recordbatch(
113                                &mut row_writer,
114                                record_batch,
115                                query_context.clone(),
116                            )
117                            .await
118                            {
119                                let (kind, err) = handle_err(e, query_context);
120                                row_writer.finish_error(kind, &err.as_bytes()).await?;
121                                return Ok(());
122                            }
123                        }
124                        Err(e) => {
125                            let (kind, err) = handle_err(e, query_context);
126                            row_writer.finish_error(kind, &err.as_bytes()).await?;
127                            return Ok(());
128                        }
129                    }
130                }
131                row_writer.finish_one().await?
132            }
133            Response::AffectedRows(rows) => {
134                MysqlResultWriter::write_affected_rows(writer, *rows, &session).await?
135            }
136        }
137    }
138
139    writer.no_more_results().await?;
140    Ok(())
141}
142
143/// Handle GreptimeDB error, convert it to MySQL error
144pub fn handle_err(e: impl ErrorExt, query_ctx: QueryContextRef) -> (ErrorKind, String) {
145    let status_code = e.status_code();
146    let kind = mysql_error_kind(&status_code);
147
148    if status_code.should_log_error() {
149        let root_error = e.root_cause().unwrap_or(&e);
150        error!(e; "Failed to handle mysql query, code: {}, error: {}, db: {}", status_code, root_error.to_string(), query_ctx.get_db_string());
151    } else {
152        debug!(
153            "Failed to handle mysql query, code: {}, db: {}, error: {:?}",
154            status_code,
155            query_ctx.get_db_string(),
156            e
157        );
158    };
159    let msg = e.output_msg();
160    // Inline the status code to output message for MySQL
161    let err_msg = format!("({status_code}): {msg}");
162
163    (kind, err_msg)
164}
165
166struct MysqlResultWriter;
167
168impl MysqlResultWriter {
169    async fn write_affected_rows<'a, W: AsyncWrite + Unpin>(
170        w: QueryResultWriter<'a, W>,
171        rows: usize,
172        session: &SessionRef,
173    ) -> io::Result<QueryResultWriter<'a, W>> {
174        let warnings = session.warnings_count() as u16;
175
176        let next_writer = w
177            .complete_one(OkResponse {
178                affected_rows: rows as u64,
179                warnings,
180                ..Default::default()
181            })
182            .await?;
183        Ok(next_writer)
184    }
185
186    async fn write_recordbatch<W: AsyncWrite + Unpin>(
187        row_writer: &mut RowWriter<'_, '_, W>,
188        record_batch: RecordBatch,
189        query_context: QueryContextRef,
190    ) -> Result<()> {
191        let schema = record_batch.schema.clone();
192        let record_batch = record_batch.into_df_record_batch();
193        for i in 0..record_batch.num_rows() {
194            for (j, column) in record_batch.columns().iter().enumerate() {
195                if column.is_null(i) {
196                    row_writer.write_col(None::<u8>)?;
197                    continue;
198                }
199
200                match column.data_type() {
201                    DataType::Null => {
202                        row_writer.write_col(None::<u8>)?;
203                    }
204                    DataType::Boolean => {
205                        let array = column.as_boolean();
206                        row_writer.write_col(array.value(i) as i8)?;
207                    }
208                    DataType::UInt8 => {
209                        let array = column.as_primitive::<UInt8Type>();
210                        row_writer.write_col(array.value(i))?;
211                    }
212                    DataType::UInt16 => {
213                        let array = column.as_primitive::<UInt16Type>();
214                        row_writer.write_col(array.value(i))?;
215                    }
216                    DataType::UInt32 => {
217                        let array = column.as_primitive::<UInt32Type>();
218                        row_writer.write_col(array.value(i))?;
219                    }
220                    DataType::UInt64 => {
221                        let array = column.as_primitive::<UInt64Type>();
222                        row_writer.write_col(array.value(i))?;
223                    }
224                    DataType::Int8 => {
225                        let array = column.as_primitive::<Int8Type>();
226                        row_writer.write_col(array.value(i))?;
227                    }
228                    DataType::Int16 => {
229                        let array = column.as_primitive::<Int16Type>();
230                        row_writer.write_col(array.value(i))?;
231                    }
232                    DataType::Int32 => {
233                        let array = column.as_primitive::<Int32Type>();
234                        row_writer.write_col(array.value(i))?;
235                    }
236                    DataType::Int64 => {
237                        let array = column.as_primitive::<Int64Type>();
238                        row_writer.write_col(array.value(i))?;
239                    }
240                    DataType::Float32 => {
241                        let array = column.as_primitive::<Float32Type>();
242                        row_writer.write_col(array.value(i))?;
243                    }
244                    DataType::Float64 => {
245                        let array = column.as_primitive::<Float64Type>();
246                        row_writer.write_col(array.value(i))?;
247                    }
248                    DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => {
249                        let v = datatypes::arrow_array::string_array_value(column, i);
250                        row_writer.write_col(v)?;
251                    }
252                    DataType::Binary | DataType::BinaryView | DataType::LargeBinary => {
253                        let v = datatypes::arrow_array::binary_array_value(column, i);
254                        if let ConcreteDataType::Json(_) = &schema.column_schemas()[j].data_type {
255                            let s = jsonb_to_string(v).context(ConvertSqlValueSnafu)?;
256                            row_writer.write_col(s)?;
257                        } else {
258                            row_writer.write_col(v)?;
259                        }
260                    }
261                    DataType::Date32 => {
262                        let array = column.as_primitive::<Date32Type>();
263                        let v = Date::new(array.value(i));
264                        row_writer.write_col(v.to_chrono_date())?;
265                    }
266                    DataType::Timestamp(_, _) => {
267                        let v = datatypes::arrow_array::timestamp_array_value(column, i);
268                        let v = v.to_chrono_datetime_with_timezone(Some(&query_context.timezone()));
269                        row_writer.write_col(v)?;
270                    }
271                    DataType::Interval(interval_unit) => match interval_unit {
272                        IntervalUnit::YearMonth => {
273                            let array = column.as_primitive::<IntervalYearMonthType>();
274                            let v: IntervalYearMonth = array.value(i).into();
275                            row_writer.write_col(v.to_iso8601_string())?;
276                        }
277                        IntervalUnit::DayTime => {
278                            let array = column.as_primitive::<IntervalDayTimeType>();
279                            let v: IntervalDayTime = array.value(i).into();
280                            row_writer.write_col(v.to_iso8601_string())?;
281                        }
282                        IntervalUnit::MonthDayNano => {
283                            let array = column.as_primitive::<IntervalMonthDayNanoType>();
284                            let v: IntervalMonthDayNano = array.value(i).into();
285                            row_writer.write_col(v.to_iso8601_string())?;
286                        }
287                    },
288                    DataType::Duration(_) => {
289                        let v: Duration =
290                            datatypes::arrow_array::duration_array_value(column, i).into();
291                        row_writer.write_col(v)?;
292                    }
293                    DataType::List(_) | DataType::Struct(_) => {
294                        let v = ScalarValue::try_from_array(column, i).context(DataFusionSnafu)?;
295                        row_writer.write_col(v.to_string())?;
296                    }
297                    DataType::Time32(_) | DataType::Time64(_) => {
298                        let time = datatypes::arrow_array::time_array_value(column, i);
299                        let v = time.to_timezone_aware_string(Some(&query_context.timezone()));
300                        row_writer.write_col(v)?;
301                    }
302                    DataType::Decimal128(precision, scale) => {
303                        let array = column.as_primitive::<Decimal128Type>();
304                        let v = Decimal128::new(array.value(i), *precision, *scale);
305                        row_writer.write_col(v.to_string())?;
306                    }
307                    _ => {
308                        return NotSupportedSnafu {
309                            feat: format!("convert {} to MySQL value", column.data_type()),
310                        }
311                        .fail();
312                    }
313                }
314            }
315            row_writer.end_row().await?;
316        }
317        Ok(())
318    }
319
320    async fn write_query_error<'a, W: AsyncWrite + Unpin>(
321        error: impl ErrorExt,
322        w: QueryResultWriter<'a, W>,
323        query_context: QueryContextRef,
324    ) -> io::Result<()> {
325        METRIC_ERROR_COUNTER
326            .with_label_values(&[METRIC_ERROR_COUNTER_LABEL_MYSQL])
327            .inc();
328
329        let (kind, err) = handle_err(error, query_context);
330        debug!("Write query error, kind: {:?}, err: {}", kind, err);
331        w.error(kind, err.as_bytes()).await?;
332        Ok(())
333    }
334}
335
336pub(crate) fn create_mysql_column(
337    data_type: &ConcreteDataType,
338    column_name: &str,
339) -> Result<Column> {
340    let column_type = match data_type {
341        ConcreteDataType::Null(_) => Ok(ColumnType::MYSQL_TYPE_NULL),
342        ConcreteDataType::Boolean(_) | ConcreteDataType::Int8(_) | ConcreteDataType::UInt8(_) => {
343            Ok(ColumnType::MYSQL_TYPE_TINY)
344        }
345        ConcreteDataType::Int16(_) | ConcreteDataType::UInt16(_) => {
346            Ok(ColumnType::MYSQL_TYPE_SHORT)
347        }
348        ConcreteDataType::Int32(_) | ConcreteDataType::UInt32(_) => Ok(ColumnType::MYSQL_TYPE_LONG),
349        ConcreteDataType::Int64(_) | ConcreteDataType::UInt64(_) => {
350            Ok(ColumnType::MYSQL_TYPE_LONGLONG)
351        }
352        ConcreteDataType::Float32(_) => Ok(ColumnType::MYSQL_TYPE_FLOAT),
353        ConcreteDataType::Float64(_) => Ok(ColumnType::MYSQL_TYPE_DOUBLE),
354        ConcreteDataType::Binary(_) | ConcreteDataType::String(_) => {
355            Ok(ColumnType::MYSQL_TYPE_VARCHAR)
356        }
357        ConcreteDataType::Timestamp(_) => Ok(ColumnType::MYSQL_TYPE_TIMESTAMP),
358        ConcreteDataType::Time(_) => Ok(ColumnType::MYSQL_TYPE_TIME),
359        ConcreteDataType::Date(_) => Ok(ColumnType::MYSQL_TYPE_DATE),
360        ConcreteDataType::Interval(_) => Ok(ColumnType::MYSQL_TYPE_VARCHAR),
361        ConcreteDataType::Duration(_) => Ok(ColumnType::MYSQL_TYPE_TIME),
362        ConcreteDataType::Decimal128(_) => Ok(ColumnType::MYSQL_TYPE_DECIMAL),
363        ConcreteDataType::Json(_) => Ok(ColumnType::MYSQL_TYPE_JSON),
364        ConcreteDataType::Vector(_) => Ok(ColumnType::MYSQL_TYPE_BLOB),
365        ConcreteDataType::List(_) => Ok(ColumnType::MYSQL_TYPE_VARCHAR),
366        ConcreteDataType::Struct(_) => Ok(ColumnType::MYSQL_TYPE_VARCHAR),
367        _ => error::UnsupportedDataTypeSnafu {
368            data_type,
369            reason: "not implemented",
370        }
371        .fail(),
372    };
373    let mut colflags = ColumnFlags::empty();
374    match data_type {
375        ConcreteDataType::UInt16(_)
376        | ConcreteDataType::UInt8(_)
377        | ConcreteDataType::UInt32(_)
378        | ConcreteDataType::UInt64(_) => colflags |= ColumnFlags::UNSIGNED_FLAG,
379        _ => {}
380    };
381    column_type.map(|column_type| Column {
382        column: column_name.to_string(),
383        coltype: column_type,
384        // TODO(LFC): Currently "table" and "colflags" are not relevant in MySQL server
385        //   implementation, will revisit them again in the future.
386        table: String::default(),
387        collen: 0, // 0 means "use default".
388        colflags,
389    })
390}
391
392/// Creates MySQL columns definition from our column schema.
393pub fn create_mysql_column_def(schema: &SchemaRef) -> Result<Vec<Column>> {
394    schema
395        .column_schemas()
396        .iter()
397        .map(|column_schema| create_mysql_column(&column_schema.data_type, &column_schema.name))
398        .collect()
399}
400
401fn mysql_error_kind(status_code: &StatusCode) -> ErrorKind {
402    match status_code {
403        StatusCode::Success => ErrorKind::ER_YES,
404        StatusCode::Unknown | StatusCode::External => ErrorKind::ER_UNKNOWN_ERROR,
405        StatusCode::Unsupported => ErrorKind::ER_NOT_SUPPORTED_YET,
406        StatusCode::Cancelled | StatusCode::DeadlineExceeded => ErrorKind::ER_QUERY_INTERRUPTED,
407        StatusCode::RuntimeResourcesExhausted => ErrorKind::ER_OUT_OF_RESOURCES,
408        StatusCode::InvalidSyntax => ErrorKind::ER_SYNTAX_ERROR,
409        StatusCode::RegionAlreadyExists | StatusCode::TableAlreadyExists => {
410            ErrorKind::ER_TABLE_EXISTS_ERROR
411        }
412        StatusCode::RegionNotFound | StatusCode::TableNotFound => ErrorKind::ER_NO_SUCH_TABLE,
413        StatusCode::RegionReadonly => ErrorKind::ER_READ_ONLY_MODE,
414        StatusCode::DatabaseNotFound => ErrorKind::ER_WRONG_DB_NAME,
415        StatusCode::UserNotFound => ErrorKind::ER_NO_SUCH_USER,
416        StatusCode::UnsupportedPasswordType => ErrorKind::ER_PASSWORD_FORMAT,
417        StatusCode::PermissionDenied | StatusCode::AccessDenied => {
418            ErrorKind::ER_ACCESS_DENIED_ERROR
419        }
420        StatusCode::UserPasswordMismatch => ErrorKind::ER_DBACCESS_DENIED_ERROR,
421        StatusCode::InvalidAuthHeader | StatusCode::AuthHeaderNotFound => {
422            ErrorKind::ER_NOT_SUPPORTED_AUTH_MODE
423        }
424        StatusCode::Unexpected
425        | StatusCode::Internal
426        | StatusCode::IllegalState
427        | StatusCode::PlanQuery
428        | StatusCode::EngineExecuteQuery
429        | StatusCode::RegionNotReady
430        | StatusCode::RegionBusy
431        | StatusCode::TableUnavailable
432        | StatusCode::StorageUnavailable
433        | StatusCode::RequestOutdated => ErrorKind::ER_INTERNAL_ERROR,
434        StatusCode::InvalidArguments => ErrorKind::ER_WRONG_ARGUMENTS,
435        StatusCode::TableColumnNotFound => ErrorKind::ER_BAD_FIELD_ERROR,
436        StatusCode::TableColumnExists => ErrorKind::ER_DUP_FIELDNAME,
437        StatusCode::DatabaseAlreadyExists => ErrorKind::ER_DB_CREATE_EXISTS,
438        StatusCode::RateLimited => ErrorKind::ER_TOO_MANY_CONCURRENT_TRXS,
439        StatusCode::FlowAlreadyExists => ErrorKind::ER_TABLE_EXISTS_ERROR,
440        StatusCode::FlowNotFound => ErrorKind::ER_NO_SUCH_TABLE,
441        StatusCode::TriggerAlreadyExists => ErrorKind::ER_TABLE_EXISTS_ERROR,
442        StatusCode::TriggerNotFound => ErrorKind::ER_NO_SUCH_TABLE,
443        StatusCode::Suspended => ErrorKind::ER_SERVER_SHUTDOWN,
444    }
445}