servers/postgres/
types.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod bytea;
16mod datetime;
17mod error;
18mod interval;
19
20use std::collections::HashMap;
21use std::sync::Arc;
22
23use arrow::array::{Array, ArrayRef, AsArray};
24use arrow::datatypes::{
25    Date32Type, Date64Type, Decimal128Type, Float32Type, Float64Type, Int8Type, Int16Type,
26    Int32Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
27    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
28    TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
29    TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
30};
31use arrow_schema::{DataType, IntervalUnit, TimeUnit};
32use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
33use common_decimal::Decimal128;
34use common_recordbatch::RecordBatch;
35use common_time::time::Time;
36use common_time::{Date, IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth, Timestamp};
37use datafusion_common::ScalarValue;
38use datafusion_expr::LogicalPlan;
39use datatypes::arrow::datatypes::DataType as ArrowDataType;
40use datatypes::json::JsonStructureSettings;
41use datatypes::prelude::{ConcreteDataType, Value};
42use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
43use datatypes::types::{IntervalType, TimestampType, jsonb_to_string};
44use datatypes::value::StructValue;
45use pgwire::api::Type;
46use pgwire::api::portal::{Format, Portal};
47use pgwire::api::results::{DataRowEncoder, FieldInfo};
48use pgwire::error::{PgWireError, PgWireResult};
49use pgwire::messages::data::DataRow;
50use session::context::QueryContextRef;
51use session::session_config::PGByteaOutputValue;
52use snafu::ResultExt;
53
54use self::bytea::{EscapeOutputBytea, HexOutputBytea};
55use self::datetime::{StylingDate, StylingDateTime};
56pub use self::error::{PgErrorCode, PgErrorSeverity};
57use self::interval::PgInterval;
58use crate::SqlPlan;
59use crate::error::{self as server_error, DataFusionSnafu, Error, Result};
60use crate::postgres::utils::convert_err;
61
62pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
63    origin
64        .column_schemas()
65        .iter()
66        .enumerate()
67        .map(|(idx, col)| {
68            Ok(FieldInfo::new(
69                col.name.clone(),
70                None,
71                None,
72                type_gt_to_pg(&col.data_type)?,
73                field_formats.format_for(idx),
74            ))
75        })
76        .collect::<Result<Vec<FieldInfo>>>()
77}
78
79/// this function will encode greptime's `StructValue` into PostgreSQL jsonb type
80///
81/// Note that greptimedb has different types of StructValue for storing json data,
82/// based on policy defined in `JsonStructureSettings`. But here the `StructValue`
83/// should be fully structured.
84///
85/// there are alternatives like records, arrays, etc. but there are also limitations:
86/// records: there is no support for include keys
87/// arrays: element in array must be the same type
88fn encode_struct(
89    _query_ctx: &QueryContextRef,
90    struct_value: StructValue,
91    builder: &mut DataRowEncoder,
92) -> PgWireResult<()> {
93    let encoding_setting = JsonStructureSettings::Structured(None);
94    let json_value = encoding_setting
95        .decode(Value::Struct(struct_value))
96        .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
97
98    builder.encode_field(&json_value)
99}
100
101fn encode_array(
102    query_ctx: &QueryContextRef,
103    array: ArrayRef,
104    builder: &mut DataRowEncoder,
105) -> PgWireResult<()> {
106    macro_rules! encode_primitive_array {
107        ($array: ident, $data_type: ty, $lower_type: ty, $upper_type: ty) => {{
108            let array = $array.iter().collect::<Vec<Option<$data_type>>>();
109            if array
110                .iter()
111                .all(|x| x.is_none_or(|i| i <= <$lower_type>::MAX as $data_type))
112            {
113                builder.encode_field(
114                    &array
115                        .into_iter()
116                        .map(|x| x.map(|i| i as $lower_type))
117                        .collect::<Vec<Option<$lower_type>>>(),
118                )
119            } else {
120                builder.encode_field(
121                    &array
122                        .into_iter()
123                        .map(|x| x.map(|i| i as $upper_type))
124                        .collect::<Vec<Option<$upper_type>>>(),
125                )
126            }
127        }};
128    }
129
130    match array.data_type() {
131        DataType::Boolean => {
132            let array = array.as_boolean();
133            let array = array.iter().collect::<Vec<_>>();
134            builder.encode_field(&array)
135        }
136        DataType::Int8 => {
137            let array = array.as_primitive::<Int8Type>();
138            let array = array.iter().collect::<Vec<_>>();
139            builder.encode_field(&array)
140        }
141        DataType::Int16 => {
142            let array = array.as_primitive::<Int16Type>();
143            let array = array.iter().collect::<Vec<_>>();
144            builder.encode_field(&array)
145        }
146        DataType::Int32 => {
147            let array = array.as_primitive::<Int32Type>();
148            let array = array.iter().collect::<Vec<_>>();
149            builder.encode_field(&array)
150        }
151        DataType::Int64 => {
152            let array = array.as_primitive::<Int64Type>();
153            let array = array.iter().collect::<Vec<_>>();
154            builder.encode_field(&array)
155        }
156        DataType::UInt8 => {
157            let array = array.as_primitive::<UInt8Type>();
158            encode_primitive_array!(array, u8, i8, i16)
159        }
160        DataType::UInt16 => {
161            let array = array.as_primitive::<UInt16Type>();
162            encode_primitive_array!(array, u16, i16, i32)
163        }
164        DataType::UInt32 => {
165            let array = array.as_primitive::<UInt32Type>();
166            encode_primitive_array!(array, u32, i32, i64)
167        }
168        DataType::UInt64 => {
169            let array = array.as_primitive::<UInt64Type>();
170            let array = array.iter().collect::<Vec<_>>();
171            if array.iter().all(|x| x.is_none_or(|i| i <= i64::MAX as u64)) {
172                builder.encode_field(
173                    &array
174                        .into_iter()
175                        .map(|x| x.map(|i| i as i64))
176                        .collect::<Vec<Option<i64>>>(),
177                )
178            } else {
179                builder.encode_field(
180                    &array
181                        .into_iter()
182                        .map(|x| x.map(|i| i.to_string()))
183                        .collect::<Vec<_>>(),
184                )
185            }
186        }
187        DataType::Float32 => {
188            let array = array.as_primitive::<Float32Type>();
189            let array = array.iter().collect::<Vec<_>>();
190            builder.encode_field(&array)
191        }
192        DataType::Float64 => {
193            let array = array.as_primitive::<Float64Type>();
194            let array = array.iter().collect::<Vec<_>>();
195            builder.encode_field(&array)
196        }
197        DataType::Binary => {
198            let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
199
200            let array = array.as_binary::<i32>();
201            match *bytea_output {
202                PGByteaOutputValue::ESCAPE => {
203                    let array = array
204                        .iter()
205                        .map(|v| v.map(EscapeOutputBytea))
206                        .collect::<Vec<_>>();
207                    builder.encode_field(&array)
208                }
209                PGByteaOutputValue::HEX => {
210                    let array = array
211                        .iter()
212                        .map(|v| v.map(HexOutputBytea))
213                        .collect::<Vec<_>>();
214                    builder.encode_field(&array)
215                }
216            }
217        }
218        DataType::Utf8 => {
219            let array = array.as_string::<i32>();
220            let array = array.into_iter().collect::<Vec<_>>();
221            builder.encode_field(&array)
222        }
223        DataType::LargeUtf8 => {
224            let array = array.as_string::<i64>();
225            let array = array.into_iter().collect::<Vec<_>>();
226            builder.encode_field(&array)
227        }
228        DataType::Utf8View => {
229            let array = array.as_string_view();
230            let array = array.into_iter().collect::<Vec<_>>();
231            builder.encode_field(&array)
232        }
233        DataType::Date32 | DataType::Date64 => {
234            let iter: Box<dyn Iterator<Item = Option<i32>>> =
235                if matches!(array.data_type(), DataType::Date32) {
236                    let array = array.as_primitive::<Date32Type>();
237                    Box::new(array.into_iter())
238                } else {
239                    let array = array.as_primitive::<Date64Type>();
240                    // `Date64` values are milliseconds representation of `Date32` values, according
241                    // to its specification. So we convert them to `Date32` values to process the
242                    // `Date64` array unified with `Date32` array.
243                    Box::new(
244                        array
245                            .into_iter()
246                            .map(|x| x.map(|i| (i / 86_400_000) as i32)),
247                    )
248                };
249            let array = iter
250                .into_iter()
251                .map(|v| match v {
252                    None => Ok(None),
253                    Some(v) => {
254                        if let Some(date) = Date::new(v).to_chrono_date() {
255                            let (style, order) =
256                                *query_ctx.configuration_parameter().pg_datetime_style();
257                            Ok(Some(StylingDate(date, style, order)))
258                        } else {
259                            Err(convert_err(Error::Internal {
260                                err_msg: format!("Failed to convert date to postgres type {v:?}",),
261                            }))
262                        }
263                    }
264                })
265                .collect::<PgWireResult<Vec<Option<StylingDate>>>>()?;
266            builder.encode_field(&array)
267        }
268        DataType::Timestamp(time_unit, _) => {
269            let array = match time_unit {
270                TimeUnit::Second => {
271                    let array = array.as_primitive::<TimestampSecondType>();
272                    array.into_iter().collect::<Vec<_>>()
273                }
274                TimeUnit::Millisecond => {
275                    let array = array.as_primitive::<TimestampMillisecondType>();
276                    array.into_iter().collect::<Vec<_>>()
277                }
278                TimeUnit::Microsecond => {
279                    let array = array.as_primitive::<TimestampMicrosecondType>();
280                    array.into_iter().collect::<Vec<_>>()
281                }
282                TimeUnit::Nanosecond => {
283                    let array = array.as_primitive::<TimestampNanosecondType>();
284                    array.into_iter().collect::<Vec<_>>()
285                }
286            };
287            let time_unit = time_unit.into();
288            let array = array
289                .into_iter()
290                .map(|v| match v {
291                    None => Ok(None),
292                    Some(v) => {
293                        let v = Timestamp::new(v, time_unit);
294                        if let Some(datetime) =
295                            v.to_chrono_datetime_with_timezone(Some(&query_ctx.timezone()))
296                        {
297                            let (style, order) =
298                                *query_ctx.configuration_parameter().pg_datetime_style();
299                            Ok(Some(StylingDateTime(datetime, style, order)))
300                        } else {
301                            Err(convert_err(Error::Internal {
302                                err_msg: format!("Failed to convert date to postgres type {v:?}",),
303                            }))
304                        }
305                    }
306                })
307                .collect::<PgWireResult<Vec<Option<StylingDateTime>>>>()?;
308            builder.encode_field(&array)
309        }
310        DataType::Time32(time_unit) | DataType::Time64(time_unit) => {
311            let iter: Box<dyn Iterator<Item = Option<Time>>> = match time_unit {
312                TimeUnit::Second => {
313                    let array = array.as_primitive::<Time32SecondType>();
314                    Box::new(
315                        array
316                            .into_iter()
317                            .map(|v| v.map(|i| Time::new_second(i as i64))),
318                    )
319                }
320                TimeUnit::Millisecond => {
321                    let array = array.as_primitive::<Time32MillisecondType>();
322                    Box::new(
323                        array
324                            .into_iter()
325                            .map(|v| v.map(|i| Time::new_millisecond(i as i64))),
326                    )
327                }
328                TimeUnit::Microsecond => {
329                    let array = array.as_primitive::<Time64MicrosecondType>();
330                    Box::new(array.into_iter().map(|v| v.map(Time::new_microsecond)))
331                }
332                TimeUnit::Nanosecond => {
333                    let array = array.as_primitive::<Time64NanosecondType>();
334                    Box::new(array.into_iter().map(|v| v.map(Time::new_nanosecond)))
335                }
336            };
337            let array = iter
338                .into_iter()
339                .map(|v| v.and_then(|v| v.to_chrono_time()))
340                .collect::<Vec<Option<NaiveTime>>>();
341            builder.encode_field(&array)
342        }
343        DataType::Interval(interval_unit) => {
344            let array = match interval_unit {
345                IntervalUnit::YearMonth => {
346                    let array = array.as_primitive::<IntervalYearMonthType>();
347                    array
348                        .into_iter()
349                        .map(|v| v.map(|i| PgInterval::from(IntervalYearMonth::from(i))))
350                        .collect::<Vec<_>>()
351                }
352                IntervalUnit::DayTime => {
353                    let array = array.as_primitive::<IntervalDayTimeType>();
354                    array
355                        .into_iter()
356                        .map(|v| v.map(|i| PgInterval::from(IntervalDayTime::from(i))))
357                        .collect::<Vec<_>>()
358                }
359                IntervalUnit::MonthDayNano => {
360                    let array = array.as_primitive::<IntervalMonthDayNanoType>();
361                    array
362                        .into_iter()
363                        .map(|v| v.map(|i| PgInterval::from(IntervalMonthDayNano::from(i))))
364                        .collect::<Vec<_>>()
365                }
366            };
367            builder.encode_field(&array)
368        }
369        DataType::Decimal128(precision, scale) => {
370            let array = array.as_primitive::<Decimal128Type>();
371            let array = array
372                .into_iter()
373                .map(|v| v.map(|i| Decimal128::new(i, *precision, *scale).to_string()))
374                .collect::<Vec<_>>();
375            builder.encode_field(&array)
376        }
377        _ => Err(convert_err(Error::Internal {
378            err_msg: format!(
379                "cannot write array type {:?} in postgres protocol: unimplemented",
380                array.data_type()
381            ),
382        })),
383    }
384}
385
386pub(crate) struct RecordBatchRowIterator {
387    query_ctx: QueryContextRef,
388    pg_schema: Arc<Vec<FieldInfo>>,
389    schema: SchemaRef,
390    record_batch: arrow::record_batch::RecordBatch,
391    i: usize,
392}
393
394impl Iterator for RecordBatchRowIterator {
395    type Item = PgWireResult<DataRow>;
396
397    fn next(&mut self) -> Option<Self::Item> {
398        if self.i < self.record_batch.num_rows() {
399            let mut encoder = DataRowEncoder::new(self.pg_schema.clone());
400            if let Err(e) = self.encode_row(self.i, &mut encoder) {
401                return Some(Err(e));
402            }
403            self.i += 1;
404            Some(encoder.finish())
405        } else {
406            None
407        }
408    }
409}
410
411impl RecordBatchRowIterator {
412    pub(crate) fn new(
413        query_ctx: QueryContextRef,
414        pg_schema: Arc<Vec<FieldInfo>>,
415        record_batch: RecordBatch,
416    ) -> Self {
417        let schema = record_batch.schema.clone();
418        let record_batch = record_batch.into_df_record_batch();
419        Self {
420            query_ctx,
421            pg_schema,
422            schema,
423            record_batch,
424            i: 0,
425        }
426    }
427
428    fn encode_row(&mut self, i: usize, encoder: &mut DataRowEncoder) -> PgWireResult<()> {
429        for (j, column) in self.record_batch.columns().iter().enumerate() {
430            if column.is_null(i) {
431                encoder.encode_field(&None::<&i8>)?;
432                continue;
433            }
434
435            match column.data_type() {
436                DataType::Null => {
437                    encoder.encode_field(&None::<&i8>)?;
438                }
439                DataType::Boolean => {
440                    let array = column.as_boolean();
441                    encoder.encode_field(&array.value(i))?;
442                }
443                DataType::UInt8 => {
444                    let array = column.as_primitive::<UInt8Type>();
445                    let value = array.value(i);
446                    if value <= i8::MAX as u8 {
447                        encoder.encode_field(&(value as i8))?;
448                    } else {
449                        encoder.encode_field(&(value as i16))?;
450                    }
451                }
452                DataType::UInt16 => {
453                    let array = column.as_primitive::<UInt16Type>();
454                    let value = array.value(i);
455                    if value <= i16::MAX as u16 {
456                        encoder.encode_field(&(value as i16))?;
457                    } else {
458                        encoder.encode_field(&(value as i32))?;
459                    }
460                }
461                DataType::UInt32 => {
462                    let array = column.as_primitive::<UInt32Type>();
463                    let value = array.value(i);
464                    if value <= i32::MAX as u32 {
465                        encoder.encode_field(&(value as i32))?;
466                    } else {
467                        encoder.encode_field(&(value as i64))?;
468                    }
469                }
470                DataType::UInt64 => {
471                    let array = column.as_primitive::<UInt64Type>();
472                    let value = array.value(i);
473                    if value <= i64::MAX as u64 {
474                        encoder.encode_field(&(value as i64))?;
475                    } else {
476                        encoder.encode_field(&value.to_string())?;
477                    }
478                }
479                DataType::Int8 => {
480                    let array = column.as_primitive::<Int8Type>();
481                    encoder.encode_field(&array.value(i))?;
482                }
483                DataType::Int16 => {
484                    let array = column.as_primitive::<Int16Type>();
485                    encoder.encode_field(&array.value(i))?;
486                }
487                DataType::Int32 => {
488                    let array = column.as_primitive::<Int32Type>();
489                    encoder.encode_field(&array.value(i))?;
490                }
491                DataType::Int64 => {
492                    let array = column.as_primitive::<Int64Type>();
493                    encoder.encode_field(&array.value(i))?;
494                }
495                DataType::Float32 => {
496                    let array = column.as_primitive::<Float32Type>();
497                    encoder.encode_field(&array.value(i))?;
498                }
499                DataType::Float64 => {
500                    let array = column.as_primitive::<Float64Type>();
501                    encoder.encode_field(&array.value(i))?;
502                }
503                DataType::Utf8 => {
504                    let array = column.as_string::<i32>();
505                    let value = array.value(i);
506                    encoder.encode_field(&value)?;
507                }
508                DataType::Utf8View => {
509                    let array = column.as_string_view();
510                    let value = array.value(i);
511                    encoder.encode_field(&value)?;
512                }
513                DataType::LargeUtf8 => {
514                    let array = column.as_string::<i64>();
515                    let value = array.value(i);
516                    encoder.encode_field(&value)?;
517                }
518                DataType::Binary => {
519                    let array = column.as_binary::<i32>();
520                    let v = array.value(i);
521                    encode_bytes(
522                        &self.schema.column_schemas()[j],
523                        v,
524                        encoder,
525                        &self.query_ctx,
526                    )?;
527                }
528                DataType::BinaryView => {
529                    let array = column.as_binary_view();
530                    let v = array.value(i);
531                    encode_bytes(
532                        &self.schema.column_schemas()[j],
533                        v,
534                        encoder,
535                        &self.query_ctx,
536                    )?;
537                }
538                DataType::LargeBinary => {
539                    let array = column.as_binary::<i64>();
540                    let v = array.value(i);
541                    encode_bytes(
542                        &self.schema.column_schemas()[j],
543                        v,
544                        encoder,
545                        &self.query_ctx,
546                    )?;
547                }
548                DataType::Date32 | DataType::Date64 => {
549                    let v = if matches!(column.data_type(), DataType::Date32) {
550                        let array = column.as_primitive::<Date32Type>();
551                        array.value(i)
552                    } else {
553                        let array = column.as_primitive::<Date64Type>();
554                        // `Date64` values are milliseconds representation of `Date32` values,
555                        // according to its specification. So we convert the `Date64` value here to
556                        // the `Date32` value to process them unified.
557                        (array.value(i) / 86_400_000) as i32
558                    };
559                    let v = Date::new(v);
560                    let date = v.to_chrono_date().map(|v| {
561                        let (style, order) =
562                            *self.query_ctx.configuration_parameter().pg_datetime_style();
563                        StylingDate(v, style, order)
564                    });
565                    encoder.encode_field(&date)?;
566                }
567                DataType::Timestamp(_, _) => {
568                    let v = datatypes::arrow_array::timestamp_array_value(column, i);
569                    let datetime = v
570                        .to_chrono_datetime_with_timezone(Some(&self.query_ctx.timezone()))
571                        .map(|v| {
572                            let (style, order) =
573                                *self.query_ctx.configuration_parameter().pg_datetime_style();
574                            StylingDateTime(v, style, order)
575                        });
576                    encoder.encode_field(&datetime)?;
577                }
578                DataType::Interval(interval_unit) => match interval_unit {
579                    IntervalUnit::YearMonth => {
580                        let array = column.as_primitive::<IntervalYearMonthType>();
581                        let v: IntervalYearMonth = array.value(i).into();
582                        encoder.encode_field(&PgInterval::from(v))?;
583                    }
584                    IntervalUnit::DayTime => {
585                        let array = column.as_primitive::<IntervalDayTimeType>();
586                        let v: IntervalDayTime = array.value(i).into();
587                        encoder.encode_field(&PgInterval::from(v))?;
588                    }
589                    IntervalUnit::MonthDayNano => {
590                        let array = column.as_primitive::<IntervalMonthDayNanoType>();
591                        let v: IntervalMonthDayNano = array.value(i).into();
592                        encoder.encode_field(&PgInterval::from(v))?;
593                    }
594                },
595                DataType::Duration(_) => {
596                    let d = datatypes::arrow_array::duration_array_value(column, i);
597                    match PgInterval::try_from(d) {
598                        Ok(i) => encoder.encode_field(&i)?,
599                        Err(e) => {
600                            return Err(convert_err(Error::Internal {
601                                err_msg: e.to_string(),
602                            }));
603                        }
604                    }
605                }
606                DataType::List(_) => {
607                    let array = column.as_list::<i32>();
608                    let items = array.value(i);
609                    encode_array(&self.query_ctx, items, encoder)?;
610                }
611                DataType::Struct(_) => {
612                    encode_struct(&self.query_ctx, Default::default(), encoder)?;
613                }
614                DataType::Time32(_) | DataType::Time64(_) => {
615                    let v = datatypes::arrow_array::time_array_value(column, i);
616                    encoder.encode_field(&v.to_chrono_time())?;
617                }
618                DataType::Decimal128(precision, scale) => {
619                    let array = column.as_primitive::<Decimal128Type>();
620                    let v = Decimal128::new(array.value(i), *precision, *scale);
621                    encoder.encode_field(&v.to_string())?;
622                }
623                _ => {
624                    return Err(convert_err(Error::Internal {
625                        err_msg: format!(
626                            "cannot convert datatype {} to postgres",
627                            column.data_type()
628                        ),
629                    }));
630                }
631            }
632        }
633        Ok(())
634    }
635}
636
637fn encode_bytes(
638    schema: &ColumnSchema,
639    v: &[u8],
640    encoder: &mut DataRowEncoder,
641    query_ctx: &QueryContextRef,
642) -> PgWireResult<()> {
643    if let ConcreteDataType::Json(_) = &schema.data_type {
644        let s = jsonb_to_string(v).map_err(convert_err)?;
645        encoder.encode_field(&s)
646    } else {
647        let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
648        match *bytea_output {
649            PGByteaOutputValue::ESCAPE => encoder.encode_field(&EscapeOutputBytea(v)),
650            PGByteaOutputValue::HEX => encoder.encode_field(&HexOutputBytea(v)),
651        }
652    }
653}
654
655pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
656    match origin {
657        &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
658        &ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
659        &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR),
660        &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2),
661        &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4),
662        &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
663        &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
664        &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
665        &ConcreteDataType::Binary(_) | &ConcreteDataType::Vector(_) => Ok(Type::BYTEA),
666        &ConcreteDataType::String(_) => Ok(Type::VARCHAR),
667        &ConcreteDataType::Date(_) => Ok(Type::DATE),
668        &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP),
669        &ConcreteDataType::Time(_) => Ok(Type::TIME),
670        &ConcreteDataType::Interval(_) => Ok(Type::INTERVAL),
671        &ConcreteDataType::Decimal128(_) => Ok(Type::NUMERIC),
672        &ConcreteDataType::Json(_) => Ok(Type::JSON),
673        ConcreteDataType::List(list) => match list.item_type() {
674            &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
675            &ConcreteDataType::Boolean(_) => Ok(Type::BOOL_ARRAY),
676            &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR_ARRAY),
677            &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2_ARRAY),
678            &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4_ARRAY),
679            &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8_ARRAY),
680            &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4_ARRAY),
681            &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8_ARRAY),
682            &ConcreteDataType::Binary(_) => Ok(Type::BYTEA_ARRAY),
683            &ConcreteDataType::String(_) => Ok(Type::VARCHAR_ARRAY),
684            &ConcreteDataType::Date(_) => Ok(Type::DATE_ARRAY),
685            &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP_ARRAY),
686            &ConcreteDataType::Time(_) => Ok(Type::TIME_ARRAY),
687            &ConcreteDataType::Interval(_) => Ok(Type::INTERVAL_ARRAY),
688            &ConcreteDataType::Decimal128(_) => Ok(Type::NUMERIC_ARRAY),
689            &ConcreteDataType::Json(_) => Ok(Type::JSON_ARRAY),
690            &ConcreteDataType::Duration(_) => Ok(Type::INTERVAL_ARRAY),
691            &ConcreteDataType::Struct(_) => Ok(Type::JSON_ARRAY),
692            &ConcreteDataType::Dictionary(_)
693            | &ConcreteDataType::Vector(_)
694            | &ConcreteDataType::List(_) => server_error::UnsupportedDataTypeSnafu {
695                data_type: origin,
696                reason: "not implemented",
697            }
698            .fail(),
699        },
700        &ConcreteDataType::Dictionary(_) => server_error::UnsupportedDataTypeSnafu {
701            data_type: origin,
702            reason: "not implemented",
703        }
704        .fail(),
705        &ConcreteDataType::Duration(_) => Ok(Type::INTERVAL),
706        &ConcreteDataType::Struct(_) => Ok(Type::JSON),
707    }
708}
709
710#[allow(dead_code)]
711pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
712    // Note that we only support a small amount of pg data types
713    match origin {
714        &Type::BOOL => Ok(ConcreteDataType::boolean_datatype()),
715        &Type::CHAR => Ok(ConcreteDataType::int8_datatype()),
716        &Type::INT2 => Ok(ConcreteDataType::int16_datatype()),
717        &Type::INT4 => Ok(ConcreteDataType::int32_datatype()),
718        &Type::INT8 => Ok(ConcreteDataType::int64_datatype()),
719        &Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
720        &Type::TIMESTAMP | &Type::TIMESTAMPTZ => Ok(ConcreteDataType::timestamp_datatype(
721            common_time::timestamp::TimeUnit::Millisecond,
722        )),
723        &Type::DATE => Ok(ConcreteDataType::date_datatype()),
724        &Type::TIME => Ok(ConcreteDataType::timestamp_datatype(
725            common_time::timestamp::TimeUnit::Microsecond,
726        )),
727        &Type::CHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
728            ConcreteDataType::int8_datatype(),
729        ))),
730        &Type::INT2_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
731            ConcreteDataType::int16_datatype(),
732        ))),
733        &Type::INT4_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
734            ConcreteDataType::int32_datatype(),
735        ))),
736        &Type::INT8_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
737            ConcreteDataType::int64_datatype(),
738        ))),
739        &Type::VARCHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
740            ConcreteDataType::string_datatype(),
741        ))),
742        _ => server_error::InternalSnafu {
743            err_msg: format!("unimplemented datatype {origin:?}"),
744        }
745        .fail(),
746    }
747}
748
749pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWireResult<String> {
750    // the index is managed from portal's parameters count so it's safe to
751    // unwrap here.
752    let param_type = portal
753        .statement
754        .parameter_types
755        .get(idx)
756        .unwrap()
757        .as_ref()
758        .unwrap_or(&Type::UNKNOWN);
759    match param_type {
760        &Type::VARCHAR | &Type::TEXT => Ok(format!(
761            "'{}'",
762            portal
763                .parameter::<String>(idx, param_type)?
764                .as_deref()
765                .unwrap_or("")
766        )),
767        &Type::BOOL => Ok(portal
768            .parameter::<bool>(idx, param_type)?
769            .map(|v| v.to_string())
770            .unwrap_or_else(|| "".to_owned())),
771        &Type::INT4 => Ok(portal
772            .parameter::<i32>(idx, param_type)?
773            .map(|v| v.to_string())
774            .unwrap_or_else(|| "".to_owned())),
775        &Type::INT8 => Ok(portal
776            .parameter::<i64>(idx, param_type)?
777            .map(|v| v.to_string())
778            .unwrap_or_else(|| "".to_owned())),
779        &Type::FLOAT4 => Ok(portal
780            .parameter::<f32>(idx, param_type)?
781            .map(|v| v.to_string())
782            .unwrap_or_else(|| "".to_owned())),
783        &Type::FLOAT8 => Ok(portal
784            .parameter::<f64>(idx, param_type)?
785            .map(|v| v.to_string())
786            .unwrap_or_else(|| "".to_owned())),
787        &Type::DATE => Ok(portal
788            .parameter::<NaiveDate>(idx, param_type)?
789            .map(|v| v.format("%Y-%m-%d").to_string())
790            .unwrap_or_else(|| "".to_owned())),
791        &Type::TIMESTAMP => Ok(portal
792            .parameter::<NaiveDateTime>(idx, param_type)?
793            .map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
794            .unwrap_or_else(|| "".to_owned())),
795        &Type::INTERVAL => Ok(portal
796            .parameter::<PgInterval>(idx, param_type)?
797            .map(|v| v.to_string())
798            .unwrap_or_else(|| "".to_owned())),
799        _ => Err(invalid_parameter_error(
800            "unsupported_parameter_type",
801            Some(param_type.to_string()),
802        )),
803    }
804}
805
806pub(super) fn invalid_parameter_error(msg: &str, detail: Option<String>) -> PgWireError {
807    let mut error_info = PgErrorCode::Ec22023.to_err_info(msg.to_string());
808    error_info.detail = detail;
809    PgWireError::UserError(Box::new(error_info))
810}
811
812fn to_timestamp_scalar_value<T>(
813    data: Option<T>,
814    unit: &TimestampType,
815    ctype: &ConcreteDataType,
816) -> PgWireResult<ScalarValue>
817where
818    T: Into<i64>,
819{
820    if let Some(n) = data {
821        Value::Timestamp(unit.create_timestamp(n.into()))
822            .try_to_scalar_value(ctype)
823            .map_err(convert_err)
824    } else {
825        Ok(ScalarValue::Null)
826    }
827}
828
829pub(super) fn parameters_to_scalar_values(
830    plan: &LogicalPlan,
831    portal: &Portal<SqlPlan>,
832) -> PgWireResult<Vec<ScalarValue>> {
833    let param_count = portal.parameter_len();
834    let mut results = Vec::with_capacity(param_count);
835
836    let client_param_types = &portal.statement.parameter_types;
837    let server_param_types = plan
838        .get_parameter_types()
839        .context(DataFusionSnafu)
840        .map_err(convert_err)?
841        .into_iter()
842        .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
843        .collect::<HashMap<_, _>>();
844
845    for idx in 0..param_count {
846        let server_type = server_param_types
847            .get(&format!("${}", idx + 1))
848            .and_then(|t| t.as_ref());
849
850        let client_type = if let Some(Some(client_given_type)) = client_param_types.get(idx) {
851            client_given_type.clone()
852        } else if let Some(server_provided_type) = &server_type {
853            type_gt_to_pg(server_provided_type).map_err(convert_err)?
854        } else {
855            return Err(invalid_parameter_error(
856                "unknown_parameter_type",
857                Some(format!(
858                    "Cannot get parameter type information for parameter {}",
859                    idx
860                )),
861            ));
862        };
863
864        let value = match &client_type {
865            &Type::VARCHAR | &Type::TEXT => {
866                let data = portal.parameter::<String>(idx, &client_type)?;
867                if let Some(server_type) = &server_type {
868                    match server_type {
869                        ConcreteDataType::String(t) => {
870                            if t.is_large() {
871                                ScalarValue::LargeUtf8(data)
872                            } else {
873                                ScalarValue::Utf8(data)
874                            }
875                        }
876                        _ => {
877                            return Err(invalid_parameter_error(
878                                "invalid_parameter_type",
879                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
880                            ));
881                        }
882                    }
883                } else {
884                    ScalarValue::Utf8(data)
885                }
886            }
887            &Type::BOOL => {
888                let data = portal.parameter::<bool>(idx, &client_type)?;
889                if let Some(server_type) = &server_type {
890                    match server_type {
891                        ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data),
892                        _ => {
893                            return Err(invalid_parameter_error(
894                                "invalid_parameter_type",
895                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
896                            ));
897                        }
898                    }
899                } else {
900                    ScalarValue::Boolean(data)
901                }
902            }
903            &Type::INT2 => {
904                let data = portal.parameter::<i16>(idx, &client_type)?;
905                if let Some(server_type) = &server_type {
906                    match server_type {
907                        ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
908                        ConcreteDataType::Int16(_) => ScalarValue::Int16(data),
909                        ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
910                        ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
911                        ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
912                        ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
913                        ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
914                        ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
915                        ConcreteDataType::Timestamp(unit) => {
916                            to_timestamp_scalar_value(data, unit, server_type)?
917                        }
918                        _ => {
919                            return Err(invalid_parameter_error(
920                                "invalid_parameter_type",
921                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
922                            ));
923                        }
924                    }
925                } else {
926                    ScalarValue::Int16(data)
927                }
928            }
929            &Type::INT4 => {
930                let data = portal.parameter::<i32>(idx, &client_type)?;
931                if let Some(server_type) = &server_type {
932                    match server_type {
933                        ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
934                        ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
935                        ConcreteDataType::Int32(_) => ScalarValue::Int32(data),
936                        ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
937                        ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
938                        ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
939                        ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
940                        ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
941                        ConcreteDataType::Timestamp(unit) => {
942                            to_timestamp_scalar_value(data, unit, server_type)?
943                        }
944                        _ => {
945                            return Err(invalid_parameter_error(
946                                "invalid_parameter_type",
947                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
948                            ));
949                        }
950                    }
951                } else {
952                    ScalarValue::Int32(data)
953                }
954            }
955            &Type::INT8 => {
956                let data = portal.parameter::<i64>(idx, &client_type)?;
957                if let Some(server_type) = &server_type {
958                    match server_type {
959                        ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
960                        ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
961                        ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
962                        ConcreteDataType::Int64(_) => ScalarValue::Int64(data),
963                        ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
964                        ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
965                        ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
966                        ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
967                        ConcreteDataType::Timestamp(unit) => {
968                            to_timestamp_scalar_value(data, unit, server_type)?
969                        }
970                        _ => {
971                            return Err(invalid_parameter_error(
972                                "invalid_parameter_type",
973                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
974                            ));
975                        }
976                    }
977                } else {
978                    ScalarValue::Int64(data)
979                }
980            }
981            &Type::FLOAT4 => {
982                let data = portal.parameter::<f32>(idx, &client_type)?;
983                if let Some(server_type) = &server_type {
984                    match server_type {
985                        ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
986                        ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
987                        ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
988                        ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
989                        ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
990                        ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
991                        ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
992                        ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
993                        ConcreteDataType::Float32(_) => ScalarValue::Float32(data),
994                        ConcreteDataType::Float64(_) => {
995                            ScalarValue::Float64(data.map(|n| n as f64))
996                        }
997                        _ => {
998                            return Err(invalid_parameter_error(
999                                "invalid_parameter_type",
1000                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1001                            ));
1002                        }
1003                    }
1004                } else {
1005                    ScalarValue::Float32(data)
1006                }
1007            }
1008            &Type::FLOAT8 => {
1009                let data = portal.parameter::<f64>(idx, &client_type)?;
1010                if let Some(server_type) = &server_type {
1011                    match server_type {
1012                        ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
1013                        ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
1014                        ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
1015                        ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
1016                        ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
1017                        ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
1018                        ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
1019                        ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
1020                        ConcreteDataType::Float32(_) => {
1021                            ScalarValue::Float32(data.map(|n| n as f32))
1022                        }
1023                        ConcreteDataType::Float64(_) => ScalarValue::Float64(data),
1024                        _ => {
1025                            return Err(invalid_parameter_error(
1026                                "invalid_parameter_type",
1027                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1028                            ));
1029                        }
1030                    }
1031                } else {
1032                    ScalarValue::Float64(data)
1033                }
1034            }
1035            &Type::TIMESTAMP => {
1036                let data = portal.parameter::<NaiveDateTime>(idx, &client_type)?;
1037                if let Some(server_type) = &server_type {
1038                    match server_type {
1039                        ConcreteDataType::Timestamp(unit) => match *unit {
1040                            TimestampType::Second(_) => ScalarValue::TimestampSecond(
1041                                data.map(|ts| ts.and_utc().timestamp()),
1042                                None,
1043                            ),
1044                            TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond(
1045                                data.map(|ts| ts.and_utc().timestamp_millis()),
1046                                None,
1047                            ),
1048                            TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond(
1049                                data.map(|ts| ts.and_utc().timestamp_micros()),
1050                                None,
1051                            ),
1052                            TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond(
1053                                data.and_then(|ts| ts.and_utc().timestamp_nanos_opt()),
1054                                None,
1055                            ),
1056                        },
1057                        _ => {
1058                            return Err(invalid_parameter_error(
1059                                "invalid_parameter_type",
1060                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1061                            ));
1062                        }
1063                    }
1064                } else {
1065                    ScalarValue::TimestampMillisecond(
1066                        data.map(|ts| ts.and_utc().timestamp_millis()),
1067                        None,
1068                    )
1069                }
1070            }
1071            &Type::TIMESTAMPTZ => {
1072                let data = portal.parameter::<DateTime<FixedOffset>>(idx, &client_type)?;
1073                if let Some(server_type) = &server_type {
1074                    match server_type {
1075                        ConcreteDataType::Timestamp(unit) => match *unit {
1076                            TimestampType::Second(_) => {
1077                                ScalarValue::TimestampSecond(data.map(|ts| ts.timestamp()), None)
1078                            }
1079                            TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond(
1080                                data.map(|ts| ts.timestamp_millis()),
1081                                None,
1082                            ),
1083                            TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond(
1084                                data.map(|ts| ts.timestamp_micros()),
1085                                None,
1086                            ),
1087                            TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond(
1088                                data.and_then(|ts| ts.timestamp_nanos_opt()),
1089                                None,
1090                            ),
1091                        },
1092                        _ => {
1093                            return Err(invalid_parameter_error(
1094                                "invalid_parameter_type",
1095                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1096                            ));
1097                        }
1098                    }
1099                } else {
1100                    ScalarValue::TimestampMillisecond(data.map(|ts| ts.timestamp_millis()), None)
1101                }
1102            }
1103            &Type::DATE => {
1104                let data = portal.parameter::<NaiveDate>(idx, &client_type)?;
1105                if let Some(server_type) = &server_type {
1106                    match server_type {
1107                        ConcreteDataType::Date(_) => ScalarValue::Date32(
1108                            data.map(|d| (d - DateTime::UNIX_EPOCH.date_naive()).num_days() as i32),
1109                        ),
1110                        _ => {
1111                            return Err(invalid_parameter_error(
1112                                "invalid_parameter_type",
1113                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1114                            ));
1115                        }
1116                    }
1117                } else {
1118                    ScalarValue::Date32(
1119                        data.map(|d| (d - DateTime::UNIX_EPOCH.date_naive()).num_days() as i32),
1120                    )
1121                }
1122            }
1123            &Type::INTERVAL => {
1124                let data = portal.parameter::<PgInterval>(idx, &client_type)?;
1125                if let Some(server_type) = &server_type {
1126                    match server_type {
1127                        ConcreteDataType::Interval(IntervalType::YearMonth(_)) => {
1128                            ScalarValue::IntervalYearMonth(
1129                                data.map(|i| {
1130                                    if i.days != 0 || i.microseconds != 0 {
1131                                        Err(invalid_parameter_error(
1132                                            "invalid_parameter_type",
1133                                            Some(format!(
1134                                                "Expected: {}, found: {}",
1135                                                server_type, client_type
1136                                            )),
1137                                        ))
1138                                    } else {
1139                                        Ok(IntervalYearMonth::new(i.months).to_i32())
1140                                    }
1141                                })
1142                                .transpose()?,
1143                            )
1144                        }
1145                        ConcreteDataType::Interval(IntervalType::DayTime(_)) => {
1146                            ScalarValue::IntervalDayTime(
1147                                data.map(|i| {
1148                                    if i.months != 0 || i.microseconds % 1000 != 0 {
1149                                        Err(invalid_parameter_error(
1150                                            "invalid_parameter_type",
1151                                            Some(format!(
1152                                                "Expected: {}, found: {}",
1153                                                server_type, client_type
1154                                            )),
1155                                        ))
1156                                    } else {
1157                                        Ok(IntervalDayTime::new(
1158                                            i.days,
1159                                            (i.microseconds / 1000) as i32,
1160                                        )
1161                                        .into())
1162                                    }
1163                                })
1164                                .transpose()?,
1165                            )
1166                        }
1167                        ConcreteDataType::Interval(IntervalType::MonthDayNano(_)) => {
1168                            ScalarValue::IntervalMonthDayNano(
1169                                data.map(|i| IntervalMonthDayNano::from(i).into()),
1170                            )
1171                        }
1172                        _ => {
1173                            return Err(invalid_parameter_error(
1174                                "invalid_parameter_type",
1175                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1176                            ));
1177                        }
1178                    }
1179                } else {
1180                    ScalarValue::IntervalMonthDayNano(
1181                        data.map(|i| IntervalMonthDayNano::from(i).into()),
1182                    )
1183                }
1184            }
1185            &Type::BYTEA => {
1186                let data = portal.parameter::<Vec<u8>>(idx, &client_type)?;
1187                if let Some(server_type) = &server_type {
1188                    match server_type {
1189                        ConcreteDataType::String(t) => {
1190                            let s = data.map(|d| String::from_utf8_lossy(&d).to_string());
1191                            if t.is_large() {
1192                                ScalarValue::LargeUtf8(s)
1193                            } else {
1194                                ScalarValue::Utf8(s)
1195                            }
1196                        }
1197                        ConcreteDataType::Binary(_) => ScalarValue::Binary(data),
1198                        _ => {
1199                            return Err(invalid_parameter_error(
1200                                "invalid_parameter_type",
1201                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1202                            ));
1203                        }
1204                    }
1205                } else {
1206                    ScalarValue::Binary(data)
1207                }
1208            }
1209            &Type::JSONB => {
1210                let data = portal.parameter::<serde_json::Value>(idx, &client_type)?;
1211                if let Some(server_type) = &server_type {
1212                    match server_type {
1213                        ConcreteDataType::Binary(_) => {
1214                            ScalarValue::Binary(data.map(|d| d.to_string().into_bytes()))
1215                        }
1216                        _ => {
1217                            return Err(invalid_parameter_error(
1218                                "invalid_parameter_type",
1219                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1220                            ));
1221                        }
1222                    }
1223                } else {
1224                    ScalarValue::Binary(data.map(|d| d.to_string().into_bytes()))
1225                }
1226            }
1227            &Type::INT2_ARRAY => {
1228                let data = portal.parameter::<Vec<i16>>(idx, &client_type)?;
1229                if let Some(data) = data {
1230                    let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1231                    ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int16, true))
1232                } else {
1233                    ScalarValue::Null
1234                }
1235            }
1236            &Type::INT4_ARRAY => {
1237                let data = portal.parameter::<Vec<i32>>(idx, &client_type)?;
1238                if let Some(data) = data {
1239                    let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1240                    ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int32, true))
1241                } else {
1242                    ScalarValue::Null
1243                }
1244            }
1245            &Type::INT8_ARRAY => {
1246                let data = portal.parameter::<Vec<i64>>(idx, &client_type)?;
1247                if let Some(data) = data {
1248                    let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1249                    ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int64, true))
1250                } else {
1251                    ScalarValue::Null
1252                }
1253            }
1254            &Type::VARCHAR_ARRAY => {
1255                let data = portal.parameter::<Vec<String>>(idx, &client_type)?;
1256                if let Some(data) = data {
1257                    let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1258                    ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Utf8, true))
1259                } else {
1260                    ScalarValue::Null
1261                }
1262            }
1263            _ => Err(invalid_parameter_error(
1264                "unsupported_parameter_value",
1265                Some(format!("Found type: {}", client_type)),
1266            ))?,
1267        };
1268
1269        results.push(value);
1270    }
1271
1272    Ok(results)
1273}
1274
1275pub(super) fn param_types_to_pg_types(
1276    param_types: &HashMap<String, Option<ConcreteDataType>>,
1277) -> Result<Vec<Type>> {
1278    let param_count = param_types.len();
1279    let mut types = Vec::with_capacity(param_count);
1280    for i in 0..param_count {
1281        if let Some(Some(param_type)) = param_types.get(&format!("${}", i + 1)) {
1282            let pg_type = type_gt_to_pg(param_type)?;
1283            types.push(pg_type);
1284        } else {
1285            types.push(Type::UNKNOWN);
1286        }
1287    }
1288    Ok(types)
1289}
1290
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{
        Float64Builder, Int64Builder, ListBuilder, StringBuilder, TimestampSecondBuilder,
    };
    use arrow_schema::Field;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{
        BinaryVector, BooleanVector, DateVector, Float32Vector, Float64Vector, Int8Vector,
        Int16Vector, Int32Vector, Int64Vector, IntervalDayTimeVector, IntervalMonthDayNanoVector,
        IntervalYearMonthVector, ListVector, NullVector, StringVector, TimeSecondVector,
        TimestampSecondVector, UInt8Vector, UInt16Vector, UInt32Vector, UInt64Vector, VectorRef,
    };
    use pgwire::api::Type;
    use pgwire::api::results::{FieldFormat, FieldInfo};
    use session::context::QueryContextBuilder;

    use super::*;

    /// Builds a text-format [`FieldInfo`] with no table/column ids attached,
    /// which is the shape of every expected field in these tests.
    fn text_field(name: &str, datatype: Type) -> FieldInfo {
        FieldInfo::new(name.to_string(), None, None, datatype, FieldFormat::Text)
    }

    /// Builds a nullable Arrow list field (named `"x"`) whose nullable items
    /// (named `"item"`) have type `item`.
    fn list_field(item: DataType) -> Field {
        Field::new("x", DataType::List(Arc::new(Field::new("item", item, true))), true)
    }

    /// Checks that `schema_to_pg` maps each GreptimeDB column type to the
    /// expected Postgres type when every column is requested in text format.
    #[test]
    fn test_schema_convert() {
        let column_schemas = vec![
            ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true),
            ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true),
            ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true),
            ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true),
            ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true),
            ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true),
            ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true),
            ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true),
            ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true),
            ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true),
            ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true),
            ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true),
            ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
            ColumnSchema::new(
                "timestamps",
                ConcreteDataType::timestamp_millisecond_datatype(),
                true,
            ),
            ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true),
            ColumnSchema::new("times", ConcreteDataType::time_second_datatype(), true),
            ColumnSchema::new(
                "intervals",
                ConcreteDataType::interval_month_day_nano_datatype(),
                true,
            ),
        ];
        // Expected mapping: 8-bit integers (signed and unsigned) are reported as
        // CHAR, and unsigned integers share the Postgres type of their signed width.
        let pg_field_info = vec![
            text_field("nulls", Type::UNKNOWN),
            text_field("bools", Type::BOOL),
            text_field("int8s", Type::CHAR),
            text_field("int16s", Type::INT2),
            text_field("int32s", Type::INT4),
            text_field("int64s", Type::INT8),
            text_field("uint8s", Type::CHAR),
            text_field("uint16s", Type::INT2),
            text_field("uint32s", Type::INT4),
            text_field("uint64s", Type::INT8),
            text_field("float32s", Type::FLOAT4),
            text_field("float64s", Type::FLOAT8),
            text_field("binaries", Type::BYTEA),
            text_field("strings", Type::VARCHAR),
            text_field("timestamps", Type::TIMESTAMP),
            text_field("dates", Type::DATE),
            text_field("times", Type::TIME),
            text_field("intervals", Type::INTERVAL),
        ];
        let schema = Schema::new(column_schemas);
        let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
        assert_eq!(fs, pg_field_info);
    }

    /// Encodes a three-row record batch covering scalar, temporal, interval and
    /// list columns in text format, and checks that every row encodes cleanly
    /// with one field per output column.
    #[test]
    fn test_encode_text_format_data() {
        let schema = vec![
            text_field("nulls", Type::UNKNOWN),
            text_field("bools", Type::BOOL),
            text_field("uint8s", Type::CHAR),
            text_field("uint16s", Type::INT2),
            text_field("uint32s", Type::INT4),
            text_field("uint64s", Type::INT8),
            text_field("int8s", Type::CHAR),
            text_field("int16s", Type::INT2),
            text_field("int32s", Type::INT4),
            text_field("int64s", Type::INT8),
            text_field("float32s", Type::FLOAT4),
            text_field("float64s", Type::FLOAT8),
            text_field("strings", Type::VARCHAR),
            text_field("binaries", Type::BYTEA),
            text_field("dates", Type::DATE),
            text_field("times", Type::TIME),
            text_field("timestamps", Type::TIMESTAMP),
            text_field("interval_year_month", Type::INTERVAL),
            text_field("interval_day_time", Type::INTERVAL),
            text_field("interval_month_day_nano", Type::INTERVAL),
            text_field("int_list", Type::INT8_ARRAY),
            text_field("float_list", Type::FLOAT8_ARRAY),
            text_field("string_list", Type::VARCHAR_ARRAY),
            text_field("timestamp_list", Type::TIMESTAMP_ARRAY),
        ];

        // The Arrow field order must line up with `schema` above and with the
        // vectors in `values` below.
        let arrow_schema = arrow_schema::Schema::new(vec![
            Field::new("x", DataType::Null, true),
            Field::new("x", DataType::Boolean, true),
            Field::new("x", DataType::UInt8, true),
            Field::new("x", DataType::UInt16, true),
            Field::new("x", DataType::UInt32, true),
            Field::new("x", DataType::UInt64, true),
            Field::new("x", DataType::Int8, true),
            Field::new("x", DataType::Int16, true),
            Field::new("x", DataType::Int32, true),
            Field::new("x", DataType::Int64, true),
            Field::new("x", DataType::Float32, true),
            Field::new("x", DataType::Float64, true),
            Field::new("x", DataType::Utf8, true),
            Field::new("x", DataType::Binary, true),
            Field::new("x", DataType::Date32, true),
            Field::new("x", DataType::Time32(TimeUnit::Second), true),
            Field::new("x", DataType::Timestamp(TimeUnit::Second, None), true),
            Field::new("x", DataType::Interval(IntervalUnit::YearMonth), true),
            Field::new("x", DataType::Interval(IntervalUnit::DayTime), true),
            Field::new("x", DataType::Interval(IntervalUnit::MonthDayNano), true),
            list_field(DataType::Int64),
            list_field(DataType::Float64),
            list_field(DataType::Utf8),
            list_field(DataType::Timestamp(TimeUnit::Second, None)),
        ]);

        // Each list column has three rows: a list containing a null element,
        // a null list, and another list containing a null element.
        let mut builder = ListBuilder::new(Int64Builder::new());
        builder.append_value([Some(1i64), None, Some(2)]);
        builder.append_null();
        builder.append_value([Some(-1i64), None, Some(-2)]);
        let i64_list_array = builder.finish();

        let mut builder = ListBuilder::new(Float64Builder::new());
        builder.append_value([Some(1.0f64), None, Some(2.0)]);
        builder.append_null();
        builder.append_value([Some(-1.0f64), None, Some(-2.0)]);
        let f64_list_array = builder.finish();

        let mut builder = ListBuilder::new(StringBuilder::new());
        builder.append_value([Some("a"), None, Some("b")]);
        builder.append_null();
        builder.append_value([Some("c"), None, Some("d")]);
        let string_list_array = builder.finish();

        let mut builder = ListBuilder::new(TimestampSecondBuilder::new());
        builder.append_value([Some(1i64), None, Some(2)]);
        builder.append_null();
        builder.append_value([Some(3i64), None, Some(4)]);
        let timestamp_list_array = builder.finish();

        let values = vec![
            Arc::new(NullVector::new(3)) as VectorRef,
            Arc::new(BooleanVector::from(vec![Some(true), Some(false), None])),
            Arc::new(UInt8Vector::from(vec![Some(u8::MAX), Some(u8::MIN), None])),
            Arc::new(UInt16Vector::from(vec![
                Some(u16::MAX),
                Some(u16::MIN),
                None,
            ])),
            Arc::new(UInt32Vector::from(vec![
                Some(u32::MAX),
                Some(u32::MIN),
                None,
            ])),
            Arc::new(UInt64Vector::from(vec![
                Some(u64::MAX),
                Some(u64::MIN),
                None,
            ])),
            Arc::new(Int8Vector::from(vec![Some(i8::MAX), Some(i8::MIN), None])),
            Arc::new(Int16Vector::from(vec![
                Some(i16::MAX),
                Some(i16::MIN),
                None,
            ])),
            Arc::new(Int32Vector::from(vec![
                Some(i32::MAX),
                Some(i32::MIN),
                None,
            ])),
            Arc::new(Int64Vector::from(vec![
                Some(i64::MAX),
                Some(i64::MIN),
                None,
            ])),
            Arc::new(Float32Vector::from(vec![
                None,
                Some(f32::MAX),
                Some(f32::MIN),
            ])),
            Arc::new(Float64Vector::from(vec![
                None,
                Some(f64::MAX),
                Some(f64::MIN),
            ])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
            ])),
            Arc::new(BinaryVector::from(vec![
                None,
                Some("hello".as_bytes().to_vec()),
                Some("world".as_bytes().to_vec()),
            ])),
            Arc::new(DateVector::from(vec![Some(1001), None, Some(1)])),
            Arc::new(TimeSecondVector::from(vec![Some(1001), None, Some(1)])),
            Arc::new(TimestampSecondVector::from(vec![
                Some(1000001),
                None,
                Some(1),
            ])),
            Arc::new(IntervalYearMonthVector::from(vec![Some(1), None, Some(2)])),
            Arc::new(IntervalDayTimeVector::from(vec![
                Some(arrow::datatypes::IntervalDayTime::new(1, 1)),
                None,
                Some(arrow::datatypes::IntervalDayTime::new(2, 2)),
            ])),
            Arc::new(IntervalMonthDayNanoVector::from(vec![
                Some(arrow::datatypes::IntervalMonthDayNano::new(1, 1, 10)),
                None,
                Some(arrow::datatypes::IntervalMonthDayNano::new(2, 2, 20)),
            ])),
            Arc::new(ListVector::from(i64_list_array)),
            Arc::new(ListVector::from(f64_list_array)),
            Arc::new(ListVector::from(string_list_array)),
            Arc::new(ListVector::from(timestamp_list_array)),
        ];
        let record_batch =
            RecordBatch::new(Arc::new(arrow_schema.try_into().unwrap()), values).unwrap();

        let query_context = QueryContextBuilder::default()
            .configuration_parameter(Default::default())
            .build()
            .into();
        let schema = Arc::new(schema);

        // Collect into a `Result` instead of filtering out errors, so a failure to
        // encode any row reports the underlying error instead of a bare length mismatch.
        let rows = RecordBatchRowIterator::new(query_context, schema.clone(), record_batch)
            .collect::<Result<Vec<_>, _>>()
            .expect("every row should encode successfully");
        assert_eq!(rows.len(), 3);
        for row in rows {
            assert_eq!(row.field_count, schema.len() as i16);
        }
    }

    /// Checks that `invalid_parameter_error` produces a user-facing `ERROR`
    /// with SQLSTATE 22023 and the given message.
    #[test]
    fn test_invalid_parameter() {
        // test for refactor with PgErrorCode
        let msg = "invalid_parameter_count";
        let error = invalid_parameter_error(msg, None);
        let PgWireError::UserError(value) = error else {
            panic!("test_invalid_parameter failed");
        };
        assert_eq!("ERROR", value.severity);
        assert_eq!("22023", value.code);
        assert_eq!(msg, value.message);
    }
}