servers/postgres/
types.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod bytea;
16mod datetime;
17mod error;
18mod interval;
19
20use std::collections::HashMap;
21use std::sync::Arc;
22
23use arrow::array::{Array, ArrayRef, AsArray};
24use arrow::datatypes::{
25    Date32Type, Date64Type, Decimal128Type, DurationMicrosecondType, DurationMillisecondType,
26    DurationNanosecondType, DurationSecondType, Float32Type, Float64Type, Int8Type, Int16Type,
27    Int32Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
28    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
29    TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
30    TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
31};
32use arrow_schema::{DataType, IntervalUnit, TimeUnit};
33use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime};
34use common_decimal::Decimal128;
35use common_recordbatch::RecordBatch;
36use common_time::time::Time;
37use common_time::{
38    Date, Duration, IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth, Timestamp,
39};
40use datafusion_common::ScalarValue;
41use datafusion_expr::LogicalPlan;
42use datatypes::arrow::datatypes::DataType as ArrowDataType;
43use datatypes::json::JsonStructureSettings;
44use datatypes::prelude::{ConcreteDataType, Value};
45use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
46use datatypes::types::{IntervalType, TimestampType, jsonb_to_string};
47use datatypes::value::StructValue;
48use pgwire::api::Type;
49use pgwire::api::portal::{Format, Portal};
50use pgwire::api::results::{DataRowEncoder, FieldInfo};
51use pgwire::error::{PgWireError, PgWireResult};
52use pgwire::messages::data::DataRow;
53use session::context::QueryContextRef;
54use session::session_config::PGByteaOutputValue;
55use snafu::ResultExt;
56
57use self::bytea::{EscapeOutputBytea, HexOutputBytea};
58use self::datetime::{StylingDate, StylingDateTime};
59pub use self::error::{PgErrorCode, PgErrorSeverity};
60use self::interval::PgInterval;
61use crate::SqlPlan;
62use crate::error::{self as server_error, DataFusionSnafu, Error, Result};
63use crate::postgres::utils::convert_err;
64
65pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
66    origin
67        .column_schemas()
68        .iter()
69        .enumerate()
70        .map(|(idx, col)| {
71            Ok(FieldInfo::new(
72                col.name.clone(),
73                None,
74                None,
75                type_gt_to_pg(&col.data_type)?,
76                field_formats.format_for(idx),
77            ))
78        })
79        .collect::<Result<Vec<FieldInfo>>>()
80}
81
82/// this function will encode greptime's `StructValue` into PostgreSQL jsonb type
83///
84/// Note that greptimedb has different types of StructValue for storing json data,
85/// based on policy defined in `JsonStructureSettings`. But here the `StructValue`
86/// should be fully structured.
87///
88/// there are alternatives like records, arrays, etc. but there are also limitations:
89/// records: there is no support for include keys
90/// arrays: element in array must be the same type
91fn encode_struct(
92    _query_ctx: &QueryContextRef,
93    struct_value: StructValue,
94    builder: &mut DataRowEncoder,
95) -> PgWireResult<()> {
96    let encoding_setting = JsonStructureSettings::Structured(None);
97    let json_value = encoding_setting
98        .decode(Value::Struct(struct_value))
99        .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
100
101    builder.encode_field(&json_value)
102}
103
104fn encode_array(
105    query_ctx: &QueryContextRef,
106    array: ArrayRef,
107    builder: &mut DataRowEncoder,
108) -> PgWireResult<()> {
109    macro_rules! encode_primitive_array {
110        ($array: ident, $data_type: ty, $lower_type: ty, $upper_type: ty) => {{
111            let array = $array.iter().collect::<Vec<Option<$data_type>>>();
112            if array
113                .iter()
114                .all(|x| x.is_none_or(|i| i <= <$lower_type>::MAX as $data_type))
115            {
116                builder.encode_field(
117                    &array
118                        .into_iter()
119                        .map(|x| x.map(|i| i as $lower_type))
120                        .collect::<Vec<Option<$lower_type>>>(),
121                )
122            } else {
123                builder.encode_field(
124                    &array
125                        .into_iter()
126                        .map(|x| x.map(|i| i as $upper_type))
127                        .collect::<Vec<Option<$upper_type>>>(),
128                )
129            }
130        }};
131    }
132
133    match array.data_type() {
134        DataType::Boolean => {
135            let array = array.as_boolean();
136            let array = array.iter().collect::<Vec<_>>();
137            builder.encode_field(&array)
138        }
139        DataType::Int8 => {
140            let array = array.as_primitive::<Int8Type>();
141            let array = array.iter().collect::<Vec<_>>();
142            builder.encode_field(&array)
143        }
144        DataType::Int16 => {
145            let array = array.as_primitive::<Int16Type>();
146            let array = array.iter().collect::<Vec<_>>();
147            builder.encode_field(&array)
148        }
149        DataType::Int32 => {
150            let array = array.as_primitive::<Int32Type>();
151            let array = array.iter().collect::<Vec<_>>();
152            builder.encode_field(&array)
153        }
154        DataType::Int64 => {
155            let array = array.as_primitive::<Int64Type>();
156            let array = array.iter().collect::<Vec<_>>();
157            builder.encode_field(&array)
158        }
159        DataType::UInt8 => {
160            let array = array.as_primitive::<UInt8Type>();
161            encode_primitive_array!(array, u8, i8, i16)
162        }
163        DataType::UInt16 => {
164            let array = array.as_primitive::<UInt16Type>();
165            encode_primitive_array!(array, u16, i16, i32)
166        }
167        DataType::UInt32 => {
168            let array = array.as_primitive::<UInt32Type>();
169            encode_primitive_array!(array, u32, i32, i64)
170        }
171        DataType::UInt64 => {
172            let array = array.as_primitive::<UInt64Type>();
173            let array = array.iter().collect::<Vec<_>>();
174            if array.iter().all(|x| x.is_none_or(|i| i <= i64::MAX as u64)) {
175                builder.encode_field(
176                    &array
177                        .into_iter()
178                        .map(|x| x.map(|i| i as i64))
179                        .collect::<Vec<Option<i64>>>(),
180                )
181            } else {
182                builder.encode_field(
183                    &array
184                        .into_iter()
185                        .map(|x| x.map(|i| i.to_string()))
186                        .collect::<Vec<_>>(),
187                )
188            }
189        }
190        DataType::Float32 => {
191            let array = array.as_primitive::<Float32Type>();
192            let array = array.iter().collect::<Vec<_>>();
193            builder.encode_field(&array)
194        }
195        DataType::Float64 => {
196            let array = array.as_primitive::<Float64Type>();
197            let array = array.iter().collect::<Vec<_>>();
198            builder.encode_field(&array)
199        }
200        DataType::Binary => {
201            let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
202
203            let array = array.as_binary::<i32>();
204            match *bytea_output {
205                PGByteaOutputValue::ESCAPE => {
206                    let array = array
207                        .iter()
208                        .map(|v| v.map(EscapeOutputBytea))
209                        .collect::<Vec<_>>();
210                    builder.encode_field(&array)
211                }
212                PGByteaOutputValue::HEX => {
213                    let array = array
214                        .iter()
215                        .map(|v| v.map(HexOutputBytea))
216                        .collect::<Vec<_>>();
217                    builder.encode_field(&array)
218                }
219            }
220        }
221        DataType::Utf8 => {
222            let array = array.as_string::<i32>();
223            let array = array.into_iter().collect::<Vec<_>>();
224            builder.encode_field(&array)
225        }
226        DataType::LargeUtf8 => {
227            let array = array.as_string::<i64>();
228            let array = array.into_iter().collect::<Vec<_>>();
229            builder.encode_field(&array)
230        }
231        DataType::Utf8View => {
232            let array = array.as_string_view();
233            let array = array.into_iter().collect::<Vec<_>>();
234            builder.encode_field(&array)
235        }
236        DataType::Date32 | DataType::Date64 => {
237            let iter: Box<dyn Iterator<Item = Option<i32>>> =
238                if matches!(array.data_type(), DataType::Date32) {
239                    let array = array.as_primitive::<Date32Type>();
240                    Box::new(array.into_iter())
241                } else {
242                    let array = array.as_primitive::<Date64Type>();
243                    // `Date64` values are milliseconds representation of `Date32` values, according
244                    // to its specification. So we convert them to `Date32` values to process the
245                    // `Date64` array unified with `Date32` array.
246                    Box::new(
247                        array
248                            .into_iter()
249                            .map(|x| x.map(|i| (i / 86_400_000) as i32)),
250                    )
251                };
252            let array = iter
253                .into_iter()
254                .map(|v| match v {
255                    None => Ok(None),
256                    Some(v) => {
257                        if let Some(date) = Date::new(v).to_chrono_date() {
258                            let (style, order) =
259                                *query_ctx.configuration_parameter().pg_datetime_style();
260                            Ok(Some(StylingDate(date, style, order)))
261                        } else {
262                            Err(convert_err(Error::Internal {
263                                err_msg: format!("Failed to convert date to postgres type {v:?}",),
264                            }))
265                        }
266                    }
267                })
268                .collect::<PgWireResult<Vec<Option<StylingDate>>>>()?;
269            builder.encode_field(&array)
270        }
271        DataType::Timestamp(time_unit, _) => {
272            let array = match time_unit {
273                TimeUnit::Second => {
274                    let array = array.as_primitive::<TimestampSecondType>();
275                    array.into_iter().collect::<Vec<_>>()
276                }
277                TimeUnit::Millisecond => {
278                    let array = array.as_primitive::<TimestampMillisecondType>();
279                    array.into_iter().collect::<Vec<_>>()
280                }
281                TimeUnit::Microsecond => {
282                    let array = array.as_primitive::<TimestampMicrosecondType>();
283                    array.into_iter().collect::<Vec<_>>()
284                }
285                TimeUnit::Nanosecond => {
286                    let array = array.as_primitive::<TimestampNanosecondType>();
287                    array.into_iter().collect::<Vec<_>>()
288                }
289            };
290            let time_unit = time_unit.into();
291            let array = array
292                .into_iter()
293                .map(|v| match v {
294                    None => Ok(None),
295                    Some(v) => {
296                        let v = Timestamp::new(v, time_unit);
297                        if let Some(datetime) =
298                            v.to_chrono_datetime_with_timezone(Some(&query_ctx.timezone()))
299                        {
300                            let (style, order) =
301                                *query_ctx.configuration_parameter().pg_datetime_style();
302                            Ok(Some(StylingDateTime(datetime, style, order)))
303                        } else {
304                            Err(convert_err(Error::Internal {
305                                err_msg: format!("Failed to convert date to postgres type {v:?}",),
306                            }))
307                        }
308                    }
309                })
310                .collect::<PgWireResult<Vec<Option<StylingDateTime>>>>()?;
311            builder.encode_field(&array)
312        }
313        DataType::Time32(time_unit) | DataType::Time64(time_unit) => {
314            let iter: Box<dyn Iterator<Item = Option<Time>>> = match time_unit {
315                TimeUnit::Second => {
316                    let array = array.as_primitive::<Time32SecondType>();
317                    Box::new(
318                        array
319                            .into_iter()
320                            .map(|v| v.map(|i| Time::new_second(i as i64))),
321                    )
322                }
323                TimeUnit::Millisecond => {
324                    let array = array.as_primitive::<Time32MillisecondType>();
325                    Box::new(
326                        array
327                            .into_iter()
328                            .map(|v| v.map(|i| Time::new_millisecond(i as i64))),
329                    )
330                }
331                TimeUnit::Microsecond => {
332                    let array = array.as_primitive::<Time64MicrosecondType>();
333                    Box::new(array.into_iter().map(|v| v.map(Time::new_microsecond)))
334                }
335                TimeUnit::Nanosecond => {
336                    let array = array.as_primitive::<Time64NanosecondType>();
337                    Box::new(array.into_iter().map(|v| v.map(Time::new_nanosecond)))
338                }
339            };
340            let array = iter
341                .into_iter()
342                .map(|v| v.and_then(|v| v.to_chrono_time()))
343                .collect::<Vec<Option<NaiveTime>>>();
344            builder.encode_field(&array)
345        }
346        DataType::Interval(interval_unit) => {
347            let array = match interval_unit {
348                IntervalUnit::YearMonth => {
349                    let array = array.as_primitive::<IntervalYearMonthType>();
350                    array
351                        .into_iter()
352                        .map(|v| v.map(|i| PgInterval::from(IntervalYearMonth::from(i))))
353                        .collect::<Vec<_>>()
354                }
355                IntervalUnit::DayTime => {
356                    let array = array.as_primitive::<IntervalDayTimeType>();
357                    array
358                        .into_iter()
359                        .map(|v| v.map(|i| PgInterval::from(IntervalDayTime::from(i))))
360                        .collect::<Vec<_>>()
361                }
362                IntervalUnit::MonthDayNano => {
363                    let array = array.as_primitive::<IntervalMonthDayNanoType>();
364                    array
365                        .into_iter()
366                        .map(|v| v.map(|i| PgInterval::from(IntervalMonthDayNano::from(i))))
367                        .collect::<Vec<_>>()
368                }
369            };
370            builder.encode_field(&array)
371        }
372        DataType::Decimal128(precision, scale) => {
373            let array = array.as_primitive::<Decimal128Type>();
374            let array = array
375                .into_iter()
376                .map(|v| v.map(|i| Decimal128::new(i, *precision, *scale).to_string()))
377                .collect::<Vec<_>>();
378            builder.encode_field(&array)
379        }
380        _ => Err(convert_err(Error::Internal {
381            err_msg: format!(
382                "cannot write array type {:?} in postgres protocol: unimplemented",
383                array.data_type()
384            ),
385        })),
386    }
387}
388
389pub(crate) struct RecordBatchRowIterator {
390    query_ctx: QueryContextRef,
391    pg_schema: Arc<Vec<FieldInfo>>,
392    schema: SchemaRef,
393    record_batch: arrow::record_batch::RecordBatch,
394    i: usize,
395}
396
397impl Iterator for RecordBatchRowIterator {
398    type Item = PgWireResult<DataRow>;
399
400    fn next(&mut self) -> Option<Self::Item> {
401        if self.i < self.record_batch.num_rows() {
402            let mut encoder = DataRowEncoder::new(self.pg_schema.clone());
403            if let Err(e) = self.encode_row(self.i, &mut encoder) {
404                return Some(Err(e));
405            }
406            self.i += 1;
407            Some(encoder.finish())
408        } else {
409            None
410        }
411    }
412}
413
414impl RecordBatchRowIterator {
415    pub(crate) fn new(
416        query_ctx: QueryContextRef,
417        pg_schema: Arc<Vec<FieldInfo>>,
418        record_batch: RecordBatch,
419    ) -> Self {
420        let schema = record_batch.schema.clone();
421        let record_batch = record_batch.into_df_record_batch();
422        Self {
423            query_ctx,
424            pg_schema,
425            schema,
426            record_batch,
427            i: 0,
428        }
429    }
430
431    fn encode_row(&mut self, i: usize, encoder: &mut DataRowEncoder) -> PgWireResult<()> {
432        for (j, column) in self.record_batch.columns().iter().enumerate() {
433            if column.is_null(i) {
434                encoder.encode_field(&None::<&i8>)?;
435                continue;
436            }
437
438            match column.data_type() {
439                DataType::Null => {
440                    encoder.encode_field(&None::<&i8>)?;
441                }
442                DataType::Boolean => {
443                    let array = column.as_boolean();
444                    encoder.encode_field(&array.value(i))?;
445                }
446                DataType::UInt8 => {
447                    let array = column.as_primitive::<UInt8Type>();
448                    let value = array.value(i);
449                    if value <= i8::MAX as u8 {
450                        encoder.encode_field(&(value as i8))?;
451                    } else {
452                        encoder.encode_field(&(value as i16))?;
453                    }
454                }
455                DataType::UInt16 => {
456                    let array = column.as_primitive::<UInt16Type>();
457                    let value = array.value(i);
458                    if value <= i16::MAX as u16 {
459                        encoder.encode_field(&(value as i16))?;
460                    } else {
461                        encoder.encode_field(&(value as i32))?;
462                    }
463                }
464                DataType::UInt32 => {
465                    let array = column.as_primitive::<UInt32Type>();
466                    let value = array.value(i);
467                    if value <= i32::MAX as u32 {
468                        encoder.encode_field(&(value as i32))?;
469                    } else {
470                        encoder.encode_field(&(value as i64))?;
471                    }
472                }
473                DataType::UInt64 => {
474                    let array = column.as_primitive::<UInt64Type>();
475                    let value = array.value(i);
476                    if value <= i64::MAX as u64 {
477                        encoder.encode_field(&(value as i64))?;
478                    } else {
479                        encoder.encode_field(&value.to_string())?;
480                    }
481                }
482                DataType::Int8 => {
483                    let array = column.as_primitive::<Int8Type>();
484                    encoder.encode_field(&array.value(i))?;
485                }
486                DataType::Int16 => {
487                    let array = column.as_primitive::<Int16Type>();
488                    encoder.encode_field(&array.value(i))?;
489                }
490                DataType::Int32 => {
491                    let array = column.as_primitive::<Int32Type>();
492                    encoder.encode_field(&array.value(i))?;
493                }
494                DataType::Int64 => {
495                    let array = column.as_primitive::<Int64Type>();
496                    encoder.encode_field(&array.value(i))?;
497                }
498                DataType::Float32 => {
499                    let array = column.as_primitive::<Float32Type>();
500                    encoder.encode_field(&array.value(i))?;
501                }
502                DataType::Float64 => {
503                    let array = column.as_primitive::<Float64Type>();
504                    encoder.encode_field(&array.value(i))?;
505                }
506                DataType::Utf8 => {
507                    let array = column.as_string::<i32>();
508                    let value = array.value(i);
509                    encoder.encode_field(&value)?;
510                }
511                DataType::Utf8View => {
512                    let array = column.as_string_view();
513                    let value = array.value(i);
514                    encoder.encode_field(&value)?;
515                }
516                DataType::LargeUtf8 => {
517                    let array = column.as_string::<i64>();
518                    let value = array.value(i);
519                    encoder.encode_field(&value)?;
520                }
521                DataType::Binary => {
522                    let array = column.as_binary::<i32>();
523                    let v = array.value(i);
524                    encode_bytes(
525                        &self.schema.column_schemas()[j],
526                        v,
527                        encoder,
528                        &self.query_ctx,
529                    )?;
530                }
531                DataType::BinaryView => {
532                    let array = column.as_binary_view();
533                    let v = array.value(i);
534                    encode_bytes(
535                        &self.schema.column_schemas()[j],
536                        v,
537                        encoder,
538                        &self.query_ctx,
539                    )?;
540                }
541                DataType::LargeBinary => {
542                    let array = column.as_binary::<i64>();
543                    let v = array.value(i);
544                    encode_bytes(
545                        &self.schema.column_schemas()[j],
546                        v,
547                        encoder,
548                        &self.query_ctx,
549                    )?;
550                }
551                DataType::Date32 | DataType::Date64 => {
552                    let v = if matches!(column.data_type(), DataType::Date32) {
553                        let array = column.as_primitive::<Date32Type>();
554                        array.value(i)
555                    } else {
556                        let array = column.as_primitive::<Date64Type>();
557                        // `Date64` values are milliseconds representation of `Date32` values,
558                        // according to its specification. So we convert the `Date64` value here to
559                        // the `Date32` value to process them unified.
560                        (array.value(i) / 86_400_000) as i32
561                    };
562                    let v = Date::new(v);
563                    let date = v.to_chrono_date().map(|v| {
564                        let (style, order) =
565                            *self.query_ctx.configuration_parameter().pg_datetime_style();
566                        StylingDate(v, style, order)
567                    });
568                    encoder.encode_field(&date)?;
569                }
570                DataType::Timestamp(time_unit, _) => {
571                    let v = match time_unit {
572                        TimeUnit::Second => {
573                            let array = column.as_primitive::<TimestampSecondType>();
574                            array.value(i)
575                        }
576                        TimeUnit::Millisecond => {
577                            let array = column.as_primitive::<TimestampMillisecondType>();
578                            array.value(i)
579                        }
580                        TimeUnit::Microsecond => {
581                            let array = column.as_primitive::<TimestampMicrosecondType>();
582                            array.value(i)
583                        }
584                        TimeUnit::Nanosecond => {
585                            let array = column.as_primitive::<TimestampNanosecondType>();
586                            array.value(i)
587                        }
588                    };
589                    let v = Timestamp::new(v, time_unit.into());
590                    let datetime = v
591                        .to_chrono_datetime_with_timezone(Some(&self.query_ctx.timezone()))
592                        .map(|v| {
593                            let (style, order) =
594                                *self.query_ctx.configuration_parameter().pg_datetime_style();
595                            StylingDateTime(v, style, order)
596                        });
597                    encoder.encode_field(&datetime)?;
598                }
599                DataType::Interval(interval_unit) => match interval_unit {
600                    IntervalUnit::YearMonth => {
601                        let array = column.as_primitive::<IntervalYearMonthType>();
602                        let v: IntervalYearMonth = array.value(i).into();
603                        encoder.encode_field(&PgInterval::from(v))?;
604                    }
605                    IntervalUnit::DayTime => {
606                        let array = column.as_primitive::<IntervalDayTimeType>();
607                        let v: IntervalDayTime = array.value(i).into();
608                        encoder.encode_field(&PgInterval::from(v))?;
609                    }
610                    IntervalUnit::MonthDayNano => {
611                        let array = column.as_primitive::<IntervalMonthDayNanoType>();
612                        let v: IntervalMonthDayNano = array.value(i).into();
613                        encoder.encode_field(&PgInterval::from(v))?;
614                    }
615                },
616                DataType::Duration(time_unit) => {
617                    let v = match time_unit {
618                        TimeUnit::Second => {
619                            let array = column.as_primitive::<DurationSecondType>();
620                            array.value(i)
621                        }
622                        TimeUnit::Millisecond => {
623                            let array = column.as_primitive::<DurationMillisecondType>();
624                            array.value(i)
625                        }
626                        TimeUnit::Microsecond => {
627                            let array = column.as_primitive::<DurationMicrosecondType>();
628                            array.value(i)
629                        }
630                        TimeUnit::Nanosecond => {
631                            let array = column.as_primitive::<DurationNanosecondType>();
632                            array.value(i)
633                        }
634                    };
635                    let d = Duration::new(v, time_unit.into());
636                    match PgInterval::try_from(d) {
637                        Ok(i) => encoder.encode_field(&i)?,
638                        Err(e) => {
639                            return Err(convert_err(Error::Internal {
640                                err_msg: e.to_string(),
641                            }));
642                        }
643                    }
644                }
645                DataType::List(_) => {
646                    let array = column.as_list::<i32>();
647                    let items = array.value(i);
648                    encode_array(&self.query_ctx, items, encoder)?;
649                }
650                DataType::Struct(_) => {
651                    encode_struct(&self.query_ctx, Default::default(), encoder)?;
652                }
653                DataType::Time32(time_unit) | DataType::Time64(time_unit) => {
654                    let v = match time_unit {
655                        TimeUnit::Second => {
656                            let array = column.as_primitive::<Time32SecondType>();
657                            Time::new_second(array.value(i) as i64)
658                        }
659                        TimeUnit::Millisecond => {
660                            let array = column.as_primitive::<Time32MillisecondType>();
661                            Time::new_millisecond(array.value(i) as i64)
662                        }
663                        TimeUnit::Microsecond => {
664                            let array = column.as_primitive::<Time64MicrosecondType>();
665                            Time::new_microsecond(array.value(i))
666                        }
667                        TimeUnit::Nanosecond => {
668                            let array = column.as_primitive::<Time64NanosecondType>();
669                            Time::new_nanosecond(array.value(i))
670                        }
671                    };
672                    encoder.encode_field(&v.to_chrono_time())?;
673                }
674                DataType::Decimal128(precision, scale) => {
675                    let array = column.as_primitive::<Decimal128Type>();
676                    let v = Decimal128::new(array.value(i), *precision, *scale);
677                    encoder.encode_field(&v.to_string())?;
678                }
679                _ => {
680                    return Err(convert_err(Error::Internal {
681                        err_msg: format!(
682                            "cannot convert datatype {} to postgres",
683                            column.data_type()
684                        ),
685                    }));
686                }
687            }
688        }
689        Ok(())
690    }
691}
692
693fn encode_bytes(
694    schema: &ColumnSchema,
695    v: &[u8],
696    encoder: &mut DataRowEncoder,
697    query_ctx: &QueryContextRef,
698) -> PgWireResult<()> {
699    if let ConcreteDataType::Json(_) = &schema.data_type {
700        let s = jsonb_to_string(v).map_err(convert_err)?;
701        encoder.encode_field(&s)
702    } else {
703        let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
704        match *bytea_output {
705            PGByteaOutputValue::ESCAPE => encoder.encode_field(&EscapeOutputBytea(v)),
706            PGByteaOutputValue::HEX => encoder.encode_field(&HexOutputBytea(v)),
707        }
708    }
709}
710
711pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
712    match origin {
713        &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
714        &ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
715        &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR),
716        &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2),
717        &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4),
718        &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
719        &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
720        &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
721        &ConcreteDataType::Binary(_) | &ConcreteDataType::Vector(_) => Ok(Type::BYTEA),
722        &ConcreteDataType::String(_) => Ok(Type::VARCHAR),
723        &ConcreteDataType::Date(_) => Ok(Type::DATE),
724        &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP),
725        &ConcreteDataType::Time(_) => Ok(Type::TIME),
726        &ConcreteDataType::Interval(_) => Ok(Type::INTERVAL),
727        &ConcreteDataType::Decimal128(_) => Ok(Type::NUMERIC),
728        &ConcreteDataType::Json(_) => Ok(Type::JSON),
729        ConcreteDataType::List(list) => match list.item_type() {
730            &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
731            &ConcreteDataType::Boolean(_) => Ok(Type::BOOL_ARRAY),
732            &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR_ARRAY),
733            &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2_ARRAY),
734            &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4_ARRAY),
735            &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8_ARRAY),
736            &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4_ARRAY),
737            &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8_ARRAY),
738            &ConcreteDataType::Binary(_) => Ok(Type::BYTEA_ARRAY),
739            &ConcreteDataType::String(_) => Ok(Type::VARCHAR_ARRAY),
740            &ConcreteDataType::Date(_) => Ok(Type::DATE_ARRAY),
741            &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP_ARRAY),
742            &ConcreteDataType::Time(_) => Ok(Type::TIME_ARRAY),
743            &ConcreteDataType::Interval(_) => Ok(Type::INTERVAL_ARRAY),
744            &ConcreteDataType::Decimal128(_) => Ok(Type::NUMERIC_ARRAY),
745            &ConcreteDataType::Json(_) => Ok(Type::JSON_ARRAY),
746            &ConcreteDataType::Duration(_) => Ok(Type::INTERVAL_ARRAY),
747            &ConcreteDataType::Struct(_) => Ok(Type::JSON_ARRAY),
748            &ConcreteDataType::Dictionary(_)
749            | &ConcreteDataType::Vector(_)
750            | &ConcreteDataType::List(_) => server_error::UnsupportedDataTypeSnafu {
751                data_type: origin,
752                reason: "not implemented",
753            }
754            .fail(),
755        },
756        &ConcreteDataType::Dictionary(_) => server_error::UnsupportedDataTypeSnafu {
757            data_type: origin,
758            reason: "not implemented",
759        }
760        .fail(),
761        &ConcreteDataType::Duration(_) => Ok(Type::INTERVAL),
762        &ConcreteDataType::Struct(_) => Ok(Type::JSON),
763    }
764}
765
766#[allow(dead_code)]
767pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
768    // Note that we only support a small amount of pg data types
769    match origin {
770        &Type::BOOL => Ok(ConcreteDataType::boolean_datatype()),
771        &Type::CHAR => Ok(ConcreteDataType::int8_datatype()),
772        &Type::INT2 => Ok(ConcreteDataType::int16_datatype()),
773        &Type::INT4 => Ok(ConcreteDataType::int32_datatype()),
774        &Type::INT8 => Ok(ConcreteDataType::int64_datatype()),
775        &Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
776        &Type::TIMESTAMP => Ok(ConcreteDataType::timestamp_datatype(
777            common_time::timestamp::TimeUnit::Millisecond,
778        )),
779        &Type::DATE => Ok(ConcreteDataType::date_datatype()),
780        &Type::TIME => Ok(ConcreteDataType::timestamp_datatype(
781            common_time::timestamp::TimeUnit::Microsecond,
782        )),
783        &Type::CHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
784            ConcreteDataType::int8_datatype(),
785        ))),
786        &Type::INT2_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
787            ConcreteDataType::int16_datatype(),
788        ))),
789        &Type::INT4_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
790            ConcreteDataType::int32_datatype(),
791        ))),
792        &Type::INT8_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
793            ConcreteDataType::int64_datatype(),
794        ))),
795        &Type::VARCHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
796            ConcreteDataType::string_datatype(),
797        ))),
798        _ => server_error::InternalSnafu {
799            err_msg: format!("unimplemented datatype {origin:?}"),
800        }
801        .fail(),
802    }
803}
804
805pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWireResult<String> {
806    // the index is managed from portal's parameters count so it's safe to
807    // unwrap here.
808    let param_type = portal.statement.parameter_types.get(idx).unwrap();
809    match param_type {
810        &Type::VARCHAR | &Type::TEXT => Ok(format!(
811            "'{}'",
812            portal
813                .parameter::<String>(idx, param_type)?
814                .as_deref()
815                .unwrap_or("")
816        )),
817        &Type::BOOL => Ok(portal
818            .parameter::<bool>(idx, param_type)?
819            .map(|v| v.to_string())
820            .unwrap_or_else(|| "".to_owned())),
821        &Type::INT4 => Ok(portal
822            .parameter::<i32>(idx, param_type)?
823            .map(|v| v.to_string())
824            .unwrap_or_else(|| "".to_owned())),
825        &Type::INT8 => Ok(portal
826            .parameter::<i64>(idx, param_type)?
827            .map(|v| v.to_string())
828            .unwrap_or_else(|| "".to_owned())),
829        &Type::FLOAT4 => Ok(portal
830            .parameter::<f32>(idx, param_type)?
831            .map(|v| v.to_string())
832            .unwrap_or_else(|| "".to_owned())),
833        &Type::FLOAT8 => Ok(portal
834            .parameter::<f64>(idx, param_type)?
835            .map(|v| v.to_string())
836            .unwrap_or_else(|| "".to_owned())),
837        &Type::DATE => Ok(portal
838            .parameter::<NaiveDate>(idx, param_type)?
839            .map(|v| v.format("%Y-%m-%d").to_string())
840            .unwrap_or_else(|| "".to_owned())),
841        &Type::TIMESTAMP => Ok(portal
842            .parameter::<NaiveDateTime>(idx, param_type)?
843            .map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
844            .unwrap_or_else(|| "".to_owned())),
845        &Type::INTERVAL => Ok(portal
846            .parameter::<PgInterval>(idx, param_type)?
847            .map(|v| v.to_string())
848            .unwrap_or_else(|| "".to_owned())),
849        _ => Err(invalid_parameter_error(
850            "unsupported_parameter_type",
851            Some(param_type.to_string()),
852        )),
853    }
854}
855
856pub(super) fn invalid_parameter_error(msg: &str, detail: Option<String>) -> PgWireError {
857    let mut error_info = PgErrorCode::Ec22023.to_err_info(msg.to_string());
858    error_info.detail = detail;
859    PgWireError::UserError(Box::new(error_info))
860}
861
862fn to_timestamp_scalar_value<T>(
863    data: Option<T>,
864    unit: &TimestampType,
865    ctype: &ConcreteDataType,
866) -> PgWireResult<ScalarValue>
867where
868    T: Into<i64>,
869{
870    if let Some(n) = data {
871        Value::Timestamp(unit.create_timestamp(n.into()))
872            .try_to_scalar_value(ctype)
873            .map_err(convert_err)
874    } else {
875        Ok(ScalarValue::Null)
876    }
877}
878
879pub(super) fn parameters_to_scalar_values(
880    plan: &LogicalPlan,
881    portal: &Portal<SqlPlan>,
882) -> PgWireResult<Vec<ScalarValue>> {
883    let param_count = portal.parameter_len();
884    let mut results = Vec::with_capacity(param_count);
885
886    let client_param_types = &portal.statement.parameter_types;
887    let param_types = plan
888        .get_parameter_types()
889        .context(DataFusionSnafu)
890        .map_err(convert_err)?
891        .into_iter()
892        .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
893        .collect::<HashMap<_, _>>();
894
895    for idx in 0..param_count {
896        let server_type = param_types
897            .get(&format!("${}", idx + 1))
898            .and_then(|t| t.as_ref());
899
900        let client_type = if let Some(client_given_type) = client_param_types.get(idx) {
901            client_given_type.clone()
902        } else if let Some(server_provided_type) = &server_type {
903            type_gt_to_pg(server_provided_type).map_err(convert_err)?
904        } else {
905            return Err(invalid_parameter_error(
906                "unknown_parameter_type",
907                Some(format!(
908                    "Cannot get parameter type information for parameter {}",
909                    idx
910                )),
911            ));
912        };
913
914        let value = match &client_type {
915            &Type::VARCHAR | &Type::TEXT => {
916                let data = portal.parameter::<String>(idx, &client_type)?;
917                if let Some(server_type) = &server_type {
918                    match server_type {
919                        ConcreteDataType::String(t) => {
920                            if t.is_large() {
921                                ScalarValue::LargeUtf8(data)
922                            } else {
923                                ScalarValue::Utf8(data)
924                            }
925                        }
926                        _ => {
927                            return Err(invalid_parameter_error(
928                                "invalid_parameter_type",
929                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
930                            ));
931                        }
932                    }
933                } else {
934                    ScalarValue::Utf8(data)
935                }
936            }
937            &Type::BOOL => {
938                let data = portal.parameter::<bool>(idx, &client_type)?;
939                if let Some(server_type) = &server_type {
940                    match server_type {
941                        ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data),
942                        _ => {
943                            return Err(invalid_parameter_error(
944                                "invalid_parameter_type",
945                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
946                            ));
947                        }
948                    }
949                } else {
950                    ScalarValue::Boolean(data)
951                }
952            }
953            &Type::INT2 => {
954                let data = portal.parameter::<i16>(idx, &client_type)?;
955                if let Some(server_type) = &server_type {
956                    match server_type {
957                        ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
958                        ConcreteDataType::Int16(_) => ScalarValue::Int16(data),
959                        ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
960                        ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
961                        ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
962                        ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
963                        ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
964                        ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
965                        ConcreteDataType::Timestamp(unit) => {
966                            to_timestamp_scalar_value(data, unit, server_type)?
967                        }
968                        _ => {
969                            return Err(invalid_parameter_error(
970                                "invalid_parameter_type",
971                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
972                            ));
973                        }
974                    }
975                } else {
976                    ScalarValue::Int16(data)
977                }
978            }
979            &Type::INT4 => {
980                let data = portal.parameter::<i32>(idx, &client_type)?;
981                if let Some(server_type) = &server_type {
982                    match server_type {
983                        ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
984                        ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
985                        ConcreteDataType::Int32(_) => ScalarValue::Int32(data),
986                        ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
987                        ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
988                        ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
989                        ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
990                        ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
991                        ConcreteDataType::Timestamp(unit) => {
992                            to_timestamp_scalar_value(data, unit, server_type)?
993                        }
994                        _ => {
995                            return Err(invalid_parameter_error(
996                                "invalid_parameter_type",
997                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
998                            ));
999                        }
1000                    }
1001                } else {
1002                    ScalarValue::Int32(data)
1003                }
1004            }
1005            &Type::INT8 => {
1006                let data = portal.parameter::<i64>(idx, &client_type)?;
1007                if let Some(server_type) = &server_type {
1008                    match server_type {
1009                        ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
1010                        ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
1011                        ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
1012                        ConcreteDataType::Int64(_) => ScalarValue::Int64(data),
1013                        ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
1014                        ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
1015                        ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
1016                        ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
1017                        ConcreteDataType::Timestamp(unit) => {
1018                            to_timestamp_scalar_value(data, unit, server_type)?
1019                        }
1020                        _ => {
1021                            return Err(invalid_parameter_error(
1022                                "invalid_parameter_type",
1023                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1024                            ));
1025                        }
1026                    }
1027                } else {
1028                    ScalarValue::Int64(data)
1029                }
1030            }
1031            &Type::FLOAT4 => {
1032                let data = portal.parameter::<f32>(idx, &client_type)?;
1033                if let Some(server_type) = &server_type {
1034                    match server_type {
1035                        ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
1036                        ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
1037                        ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
1038                        ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
1039                        ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
1040                        ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
1041                        ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
1042                        ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
1043                        ConcreteDataType::Float32(_) => ScalarValue::Float32(data),
1044                        ConcreteDataType::Float64(_) => {
1045                            ScalarValue::Float64(data.map(|n| n as f64))
1046                        }
1047                        _ => {
1048                            return Err(invalid_parameter_error(
1049                                "invalid_parameter_type",
1050                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1051                            ));
1052                        }
1053                    }
1054                } else {
1055                    ScalarValue::Float32(data)
1056                }
1057            }
1058            &Type::FLOAT8 => {
1059                let data = portal.parameter::<f64>(idx, &client_type)?;
1060                if let Some(server_type) = &server_type {
1061                    match server_type {
1062                        ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
1063                        ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
1064                        ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
1065                        ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
1066                        ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
1067                        ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
1068                        ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
1069                        ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
1070                        ConcreteDataType::Float32(_) => {
1071                            ScalarValue::Float32(data.map(|n| n as f32))
1072                        }
1073                        ConcreteDataType::Float64(_) => ScalarValue::Float64(data),
1074                        _ => {
1075                            return Err(invalid_parameter_error(
1076                                "invalid_parameter_type",
1077                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1078                            ));
1079                        }
1080                    }
1081                } else {
1082                    ScalarValue::Float64(data)
1083                }
1084            }
1085            &Type::TIMESTAMP => {
1086                let data = portal.parameter::<NaiveDateTime>(idx, &client_type)?;
1087                if let Some(server_type) = &server_type {
1088                    match server_type {
1089                        ConcreteDataType::Timestamp(unit) => match *unit {
1090                            TimestampType::Second(_) => ScalarValue::TimestampSecond(
1091                                data.map(|ts| ts.and_utc().timestamp()),
1092                                None,
1093                            ),
1094                            TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond(
1095                                data.map(|ts| ts.and_utc().timestamp_millis()),
1096                                None,
1097                            ),
1098                            TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond(
1099                                data.map(|ts| ts.and_utc().timestamp_micros()),
1100                                None,
1101                            ),
1102                            TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond(
1103                                data.map(|ts| ts.and_utc().timestamp_micros()),
1104                                None,
1105                            ),
1106                        },
1107                        _ => {
1108                            return Err(invalid_parameter_error(
1109                                "invalid_parameter_type",
1110                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1111                            ));
1112                        }
1113                    }
1114                } else {
1115                    ScalarValue::TimestampMillisecond(
1116                        data.map(|ts| ts.and_utc().timestamp_millis()),
1117                        None,
1118                    )
1119                }
1120            }
1121            &Type::DATE => {
1122                let data = portal.parameter::<NaiveDate>(idx, &client_type)?;
1123                if let Some(server_type) = &server_type {
1124                    match server_type {
1125                        ConcreteDataType::Date(_) => ScalarValue::Date32(
1126                            data.map(|d| (d - DateTime::UNIX_EPOCH.date_naive()).num_days() as i32),
1127                        ),
1128                        _ => {
1129                            return Err(invalid_parameter_error(
1130                                "invalid_parameter_type",
1131                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1132                            ));
1133                        }
1134                    }
1135                } else {
1136                    ScalarValue::Date32(
1137                        data.map(|d| (d - DateTime::UNIX_EPOCH.date_naive()).num_days() as i32),
1138                    )
1139                }
1140            }
1141            &Type::INTERVAL => {
1142                let data = portal.parameter::<PgInterval>(idx, &client_type)?;
1143                if let Some(server_type) = &server_type {
1144                    match server_type {
1145                        ConcreteDataType::Interval(IntervalType::YearMonth(_)) => {
1146                            ScalarValue::IntervalYearMonth(
1147                                data.map(|i| {
1148                                    if i.days != 0 || i.microseconds != 0 {
1149                                        Err(invalid_parameter_error(
1150                                            "invalid_parameter_type",
1151                                            Some(format!(
1152                                                "Expected: {}, found: {}",
1153                                                server_type, client_type
1154                                            )),
1155                                        ))
1156                                    } else {
1157                                        Ok(IntervalYearMonth::new(i.months).to_i32())
1158                                    }
1159                                })
1160                                .transpose()?,
1161                            )
1162                        }
1163                        ConcreteDataType::Interval(IntervalType::DayTime(_)) => {
1164                            ScalarValue::IntervalDayTime(
1165                                data.map(|i| {
1166                                    if i.months != 0 || i.microseconds % 1000 != 0 {
1167                                        Err(invalid_parameter_error(
1168                                            "invalid_parameter_type",
1169                                            Some(format!(
1170                                                "Expected: {}, found: {}",
1171                                                server_type, client_type
1172                                            )),
1173                                        ))
1174                                    } else {
1175                                        Ok(IntervalDayTime::new(
1176                                            i.days,
1177                                            (i.microseconds / 1000) as i32,
1178                                        )
1179                                        .into())
1180                                    }
1181                                })
1182                                .transpose()?,
1183                            )
1184                        }
1185                        ConcreteDataType::Interval(IntervalType::MonthDayNano(_)) => {
1186                            ScalarValue::IntervalMonthDayNano(
1187                                data.map(|i| IntervalMonthDayNano::from(i).into()),
1188                            )
1189                        }
1190                        _ => {
1191                            return Err(invalid_parameter_error(
1192                                "invalid_parameter_type",
1193                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1194                            ));
1195                        }
1196                    }
1197                } else {
1198                    ScalarValue::IntervalMonthDayNano(
1199                        data.map(|i| IntervalMonthDayNano::from(i).into()),
1200                    )
1201                }
1202            }
1203            &Type::BYTEA => {
1204                let data = portal.parameter::<Vec<u8>>(idx, &client_type)?;
1205                if let Some(server_type) = &server_type {
1206                    match server_type {
1207                        ConcreteDataType::String(t) => {
1208                            let s = data.map(|d| String::from_utf8_lossy(&d).to_string());
1209                            if t.is_large() {
1210                                ScalarValue::LargeUtf8(s)
1211                            } else {
1212                                ScalarValue::Utf8(s)
1213                            }
1214                        }
1215                        ConcreteDataType::Binary(_) => ScalarValue::Binary(data),
1216                        _ => {
1217                            return Err(invalid_parameter_error(
1218                                "invalid_parameter_type",
1219                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1220                            ));
1221                        }
1222                    }
1223                } else {
1224                    ScalarValue::Binary(data)
1225                }
1226            }
1227            &Type::JSONB => {
1228                let data = portal.parameter::<serde_json::Value>(idx, &client_type)?;
1229                if let Some(server_type) = &server_type {
1230                    match server_type {
1231                        ConcreteDataType::Binary(_) => {
1232                            ScalarValue::Binary(data.map(|d| d.to_string().into_bytes()))
1233                        }
1234                        _ => {
1235                            return Err(invalid_parameter_error(
1236                                "invalid_parameter_type",
1237                                Some(format!("Expected: {}, found: {}", server_type, client_type)),
1238                            ));
1239                        }
1240                    }
1241                } else {
1242                    ScalarValue::Binary(data.map(|d| d.to_string().into_bytes()))
1243                }
1244            }
1245            &Type::INT2_ARRAY => {
1246                let data = portal.parameter::<Vec<i16>>(idx, &client_type)?;
1247                if let Some(data) = data {
1248                    let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1249                    ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int16, true))
1250                } else {
1251                    ScalarValue::Null
1252                }
1253            }
1254            &Type::INT4_ARRAY => {
1255                let data = portal.parameter::<Vec<i32>>(idx, &client_type)?;
1256                if let Some(data) = data {
1257                    let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1258                    ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int32, true))
1259                } else {
1260                    ScalarValue::Null
1261                }
1262            }
1263            &Type::INT8_ARRAY => {
1264                let data = portal.parameter::<Vec<i64>>(idx, &client_type)?;
1265                if let Some(data) = data {
1266                    let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1267                    ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int64, true))
1268                } else {
1269                    ScalarValue::Null
1270                }
1271            }
1272            &Type::VARCHAR_ARRAY => {
1273                let data = portal.parameter::<Vec<String>>(idx, &client_type)?;
1274                if let Some(data) = data {
1275                    let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1276                    ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Utf8, true))
1277                } else {
1278                    ScalarValue::Null
1279                }
1280            }
1281            _ => Err(invalid_parameter_error(
1282                "unsupported_parameter_value",
1283                Some(format!("Found type: {}", client_type)),
1284            ))?,
1285        };
1286
1287        results.push(value);
1288    }
1289
1290    Ok(results)
1291}
1292
1293pub(super) fn param_types_to_pg_types(
1294    param_types: &HashMap<String, Option<ConcreteDataType>>,
1295) -> Result<Vec<Type>> {
1296    let param_count = param_types.len();
1297    let mut types = Vec::with_capacity(param_count);
1298    for i in 0..param_count {
1299        if let Some(Some(param_type)) = param_types.get(&format!("${}", i + 1)) {
1300            let pg_type = type_gt_to_pg(param_type)?;
1301            types.push(pg_type);
1302        } else {
1303            types.push(Type::UNKNOWN);
1304        }
1305    }
1306    Ok(types)
1307}
1308
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{
        Float64Builder, Int64Builder, ListBuilder, StringBuilder, TimestampSecondBuilder,
    };
    use arrow_schema::Field;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{
        BinaryVector, BooleanVector, DateVector, Float32Vector, Float64Vector, Int8Vector,
        Int16Vector, Int32Vector, Int64Vector, IntervalDayTimeVector, IntervalMonthDayNanoVector,
        IntervalYearMonthVector, ListVector, NullVector, StringVector, TimeSecondVector,
        TimestampSecondVector, UInt8Vector, UInt16Vector, UInt32Vector, UInt64Vector, VectorRef,
    };
    use pgwire::api::Type;
    use pgwire::api::results::{FieldFormat, FieldInfo};
    use session::context::QueryContextBuilder;

    use super::*;

    /// Each GreptimeDB column type maps to the expected PostgreSQL field
    /// description when every column is requested in text format.
    #[test]
    fn test_schema_convert() {
        let column_schemas = vec![
            ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true),
            ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true),
            ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true),
            ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true),
            ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true),
            ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true),
            ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true),
            ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true),
            ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true),
            ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true),
            ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true),
            ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true),
            ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
            ColumnSchema::new(
                "timestamps",
                ConcreteDataType::timestamp_millisecond_datatype(),
                true,
            ),
            ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true),
            ColumnSchema::new("times", ConcreteDataType::time_second_datatype(), true),
            ColumnSchema::new(
                "intervals",
                ConcreteDataType::interval_month_day_nano_datatype(),
                true,
            ),
        ];
        let pg_field_info = vec![
            FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
            FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
            FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
            FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
            FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
            FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
            FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
            FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
            FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
            FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
            FieldInfo::new(
                "float32s".into(),
                None,
                None,
                Type::FLOAT4,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "float64s".into(),
                None,
                None,
                Type::FLOAT8,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "binaries".into(),
                None,
                None,
                Type::BYTEA,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "strings".into(),
                None,
                None,
                Type::VARCHAR,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "timestamps".into(),
                None,
                None,
                Type::TIMESTAMP,
                FieldFormat::Text,
            ),
            FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
            FieldInfo::new("times".into(), None, None, Type::TIME, FieldFormat::Text),
            FieldInfo::new(
                "intervals".into(),
                None,
                None,
                Type::INTERVAL,
                FieldFormat::Text,
            ),
        ];
        let schema = Schema::new(column_schemas);
        let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
        assert_eq!(fs, pg_field_info);
    }

    /// Every row of a record batch covering scalar, interval and list columns
    /// must encode successfully in text format, with one field per column.
    #[test]
    fn test_encode_text_format_data() {
        let schema = vec![
            FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
            FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
            FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
            FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
            FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
            FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
            FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
            FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
            FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
            FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
            FieldInfo::new(
                "float32s".into(),
                None,
                None,
                Type::FLOAT4,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "float64s".into(),
                None,
                None,
                Type::FLOAT8,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "strings".into(),
                None,
                None,
                Type::VARCHAR,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "binaries".into(),
                None,
                None,
                Type::BYTEA,
                FieldFormat::Text,
            ),
            FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
            FieldInfo::new("times".into(), None, None, Type::TIME, FieldFormat::Text),
            FieldInfo::new(
                "timestamps".into(),
                None,
                None,
                Type::TIMESTAMP,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "interval_year_month".into(),
                None,
                None,
                Type::INTERVAL,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "interval_day_time".into(),
                None,
                None,
                Type::INTERVAL,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "interval_month_day_nano".into(),
                None,
                None,
                Type::INTERVAL,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "int_list".into(),
                None,
                None,
                Type::INT8_ARRAY,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "float_list".into(),
                None,
                None,
                Type::FLOAT8_ARRAY,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "string_list".into(),
                None,
                None,
                Type::VARCHAR_ARRAY,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "timestamp_list".into(),
                None,
                None,
                Type::TIMESTAMP_ARRAY,
                FieldFormat::Text,
            ),
        ];

        let arrow_schema = arrow_schema::Schema::new(vec![
            Field::new("x", DataType::Null, true),
            Field::new("x", DataType::Boolean, true),
            Field::new("x", DataType::UInt8, true),
            Field::new("x", DataType::UInt16, true),
            Field::new("x", DataType::UInt32, true),
            Field::new("x", DataType::UInt64, true),
            Field::new("x", DataType::Int8, true),
            Field::new("x", DataType::Int16, true),
            Field::new("x", DataType::Int32, true),
            Field::new("x", DataType::Int64, true),
            Field::new("x", DataType::Float32, true),
            Field::new("x", DataType::Float64, true),
            Field::new("x", DataType::Utf8, true),
            Field::new("x", DataType::Binary, true),
            Field::new("x", DataType::Date32, true),
            Field::new("x", DataType::Time32(TimeUnit::Second), true),
            Field::new("x", DataType::Timestamp(TimeUnit::Second, None), true),
            Field::new("x", DataType::Interval(IntervalUnit::YearMonth), true),
            Field::new("x", DataType::Interval(IntervalUnit::DayTime), true),
            Field::new("x", DataType::Interval(IntervalUnit::MonthDayNano), true),
            Field::new(
                "x",
                DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
                true,
            ),
            Field::new(
                "x",
                DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
                true,
            ),
            Field::new(
                "x",
                DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
                true,
            ),
            Field::new(
                "x",
                DataType::List(Arc::new(Field::new(
                    "item",
                    DataType::Timestamp(TimeUnit::Second, None),
                    true,
                ))),
                true,
            ),
        ]);

        let mut builder = ListBuilder::new(Int64Builder::new());
        builder.append_value([Some(1i64), None, Some(2)]);
        builder.append_null();
        builder.append_value([Some(-1i64), None, Some(-2)]);
        let i64_list_array = builder.finish();

        let mut builder = ListBuilder::new(Float64Builder::new());
        builder.append_value([Some(1.0f64), None, Some(2.0)]);
        builder.append_null();
        builder.append_value([Some(-1.0f64), None, Some(-2.0)]);
        let f64_list_array = builder.finish();

        let mut builder = ListBuilder::new(StringBuilder::new());
        builder.append_value([Some("a"), None, Some("b")]);
        builder.append_null();
        builder.append_value([Some("c"), None, Some("d")]);
        let string_list_array = builder.finish();

        let mut builder = ListBuilder::new(TimestampSecondBuilder::new());
        builder.append_value([Some(1i64), None, Some(2)]);
        builder.append_null();
        builder.append_value([Some(3i64), None, Some(4)]);
        let timestamp_list_array = builder.finish();

        let values = vec![
            Arc::new(NullVector::new(3)) as VectorRef,
            Arc::new(BooleanVector::from(vec![Some(true), Some(false), None])),
            Arc::new(UInt8Vector::from(vec![Some(u8::MAX), Some(u8::MIN), None])),
            Arc::new(UInt16Vector::from(vec![
                Some(u16::MAX),
                Some(u16::MIN),
                None,
            ])),
            Arc::new(UInt32Vector::from(vec![
                Some(u32::MAX),
                Some(u32::MIN),
                None,
            ])),
            Arc::new(UInt64Vector::from(vec![
                Some(u64::MAX),
                Some(u64::MIN),
                None,
            ])),
            Arc::new(Int8Vector::from(vec![Some(i8::MAX), Some(i8::MIN), None])),
            Arc::new(Int16Vector::from(vec![
                Some(i16::MAX),
                Some(i16::MIN),
                None,
            ])),
            Arc::new(Int32Vector::from(vec![
                Some(i32::MAX),
                Some(i32::MIN),
                None,
            ])),
            Arc::new(Int64Vector::from(vec![
                Some(i64::MAX),
                Some(i64::MIN),
                None,
            ])),
            Arc::new(Float32Vector::from(vec![
                None,
                Some(f32::MAX),
                Some(f32::MIN),
            ])),
            Arc::new(Float64Vector::from(vec![
                None,
                Some(f64::MAX),
                Some(f64::MIN),
            ])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
            ])),
            Arc::new(BinaryVector::from(vec![
                None,
                Some("hello".as_bytes().to_vec()),
                Some("world".as_bytes().to_vec()),
            ])),
            Arc::new(DateVector::from(vec![Some(1001), None, Some(1)])),
            Arc::new(TimeSecondVector::from(vec![Some(1001), None, Some(1)])),
            Arc::new(TimestampSecondVector::from(vec![
                Some(1000001),
                None,
                Some(1),
            ])),
            Arc::new(IntervalYearMonthVector::from(vec![Some(1), None, Some(2)])),
            Arc::new(IntervalDayTimeVector::from(vec![
                Some(arrow::datatypes::IntervalDayTime::new(1, 1)),
                None,
                Some(arrow::datatypes::IntervalDayTime::new(2, 2)),
            ])),
            Arc::new(IntervalMonthDayNanoVector::from(vec![
                Some(arrow::datatypes::IntervalMonthDayNano::new(1, 1, 10)),
                None,
                Some(arrow::datatypes::IntervalMonthDayNano::new(2, 2, 20)),
            ])),
            Arc::new(ListVector::from(i64_list_array)),
            Arc::new(ListVector::from(f64_list_array)),
            Arc::new(ListVector::from(string_list_array)),
            Arc::new(ListVector::from(timestamp_list_array)),
        ];
        let record_batch =
            RecordBatch::new(Arc::new(arrow_schema.try_into().unwrap()), values).unwrap();

        let query_context = QueryContextBuilder::default()
            .configuration_parameter(Default::default())
            .build()
            .into();
        let schema = Arc::new(schema);

        // Fail loudly with the actual encoding error instead of silently
        // dropping failed rows and only noticing via the row count.
        let rows = RecordBatchRowIterator::new(query_context, schema.clone(), record_batch)
            .map(|row| row.expect("failed to encode row in text format"))
            .collect::<Vec<_>>();
        assert_eq!(rows.len(), 3);
        for row in rows {
            assert_eq!(row.field_count, schema.len() as i16);
        }
    }

    /// Unresolved parameter types are reported as `UNKNOWN`, one entry per
    /// placeholder, and an empty map describes no parameters.
    #[test]
    fn test_param_types_to_pg_types_unknown() {
        let empty = HashMap::new();
        assert!(param_types_to_pg_types(&empty).unwrap().is_empty());

        let mut param_types = HashMap::new();
        param_types.insert("$1".to_string(), None);
        param_types.insert("$2".to_string(), None);
        assert_eq!(
            param_types_to_pg_types(&param_types).unwrap(),
            vec![Type::UNKNOWN, Type::UNKNOWN]
        );
    }

    /// `invalid_parameter_error` produces a user-facing error carrying SQLSTATE
    /// 22023 (invalid_parameter_value) and the supplied message.
    #[test]
    fn test_invalid_parameter() {
        // test for refactor with PgErrorCode
        let msg = "invalid_parameter_count";
        let error = invalid_parameter_error(msg, None);
        if let PgWireError::UserError(value) = error {
            assert_eq!("ERROR", value.severity);
            assert_eq!("22023", value.code);
            assert_eq!(msg, value.message);
        } else {
            panic!("test_invalid_parameter failed");
        }
    }
}