1mod bytea;
16mod datetime;
17mod error;
18mod interval;
19
20use std::collections::HashMap;
21use std::sync::Arc;
22
23use arrow::array::{Array, ArrayRef, AsArray};
24use arrow::datatypes::{
25 Date32Type, Date64Type, Decimal128Type, DurationMicrosecondType, DurationMillisecondType,
26 DurationNanosecondType, DurationSecondType, Float32Type, Float64Type, Int8Type, Int16Type,
27 Int32Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
28 Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
29 TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
30 TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
31};
32use arrow_schema::{DataType, IntervalUnit, TimeUnit};
33use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime};
34use common_decimal::Decimal128;
35use common_recordbatch::RecordBatch;
36use common_time::time::Time;
37use common_time::{
38 Date, Duration, IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth, Timestamp,
39};
40use datafusion_common::ScalarValue;
41use datafusion_expr::LogicalPlan;
42use datatypes::arrow::datatypes::DataType as ArrowDataType;
43use datatypes::json::JsonStructureSettings;
44use datatypes::prelude::{ConcreteDataType, Value};
45use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
46use datatypes::types::{IntervalType, TimestampType, jsonb_to_string};
47use datatypes::value::StructValue;
48use pgwire::api::Type;
49use pgwire::api::portal::{Format, Portal};
50use pgwire::api::results::{DataRowEncoder, FieldInfo};
51use pgwire::error::{PgWireError, PgWireResult};
52use pgwire::messages::data::DataRow;
53use session::context::QueryContextRef;
54use session::session_config::PGByteaOutputValue;
55use snafu::ResultExt;
56
57use self::bytea::{EscapeOutputBytea, HexOutputBytea};
58use self::datetime::{StylingDate, StylingDateTime};
59pub use self::error::{PgErrorCode, PgErrorSeverity};
60use self::interval::PgInterval;
61use crate::SqlPlan;
62use crate::error::{self as server_error, DataFusionSnafu, Error, Result};
63use crate::postgres::utils::convert_err;
64
65pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
66 origin
67 .column_schemas()
68 .iter()
69 .enumerate()
70 .map(|(idx, col)| {
71 Ok(FieldInfo::new(
72 col.name.clone(),
73 None,
74 None,
75 type_gt_to_pg(&col.data_type)?,
76 field_formats.format_for(idx),
77 ))
78 })
79 .collect::<Result<Vec<FieldInfo>>>()
80}
81
82fn encode_struct(
92 _query_ctx: &QueryContextRef,
93 struct_value: StructValue,
94 builder: &mut DataRowEncoder,
95) -> PgWireResult<()> {
96 let encoding_setting = JsonStructureSettings::Structured(None);
97 let json_value = encoding_setting
98 .decode(Value::Struct(struct_value))
99 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
100
101 builder.encode_field(&json_value)
102}
103
104fn encode_array(
105 query_ctx: &QueryContextRef,
106 array: ArrayRef,
107 builder: &mut DataRowEncoder,
108) -> PgWireResult<()> {
109 macro_rules! encode_primitive_array {
110 ($array: ident, $data_type: ty, $lower_type: ty, $upper_type: ty) => {{
111 let array = $array.iter().collect::<Vec<Option<$data_type>>>();
112 if array
113 .iter()
114 .all(|x| x.is_none_or(|i| i <= <$lower_type>::MAX as $data_type))
115 {
116 builder.encode_field(
117 &array
118 .into_iter()
119 .map(|x| x.map(|i| i as $lower_type))
120 .collect::<Vec<Option<$lower_type>>>(),
121 )
122 } else {
123 builder.encode_field(
124 &array
125 .into_iter()
126 .map(|x| x.map(|i| i as $upper_type))
127 .collect::<Vec<Option<$upper_type>>>(),
128 )
129 }
130 }};
131 }
132
133 match array.data_type() {
134 DataType::Boolean => {
135 let array = array.as_boolean();
136 let array = array.iter().collect::<Vec<_>>();
137 builder.encode_field(&array)
138 }
139 DataType::Int8 => {
140 let array = array.as_primitive::<Int8Type>();
141 let array = array.iter().collect::<Vec<_>>();
142 builder.encode_field(&array)
143 }
144 DataType::Int16 => {
145 let array = array.as_primitive::<Int16Type>();
146 let array = array.iter().collect::<Vec<_>>();
147 builder.encode_field(&array)
148 }
149 DataType::Int32 => {
150 let array = array.as_primitive::<Int32Type>();
151 let array = array.iter().collect::<Vec<_>>();
152 builder.encode_field(&array)
153 }
154 DataType::Int64 => {
155 let array = array.as_primitive::<Int64Type>();
156 let array = array.iter().collect::<Vec<_>>();
157 builder.encode_field(&array)
158 }
159 DataType::UInt8 => {
160 let array = array.as_primitive::<UInt8Type>();
161 encode_primitive_array!(array, u8, i8, i16)
162 }
163 DataType::UInt16 => {
164 let array = array.as_primitive::<UInt16Type>();
165 encode_primitive_array!(array, u16, i16, i32)
166 }
167 DataType::UInt32 => {
168 let array = array.as_primitive::<UInt32Type>();
169 encode_primitive_array!(array, u32, i32, i64)
170 }
171 DataType::UInt64 => {
172 let array = array.as_primitive::<UInt64Type>();
173 let array = array.iter().collect::<Vec<_>>();
174 if array.iter().all(|x| x.is_none_or(|i| i <= i64::MAX as u64)) {
175 builder.encode_field(
176 &array
177 .into_iter()
178 .map(|x| x.map(|i| i as i64))
179 .collect::<Vec<Option<i64>>>(),
180 )
181 } else {
182 builder.encode_field(
183 &array
184 .into_iter()
185 .map(|x| x.map(|i| i.to_string()))
186 .collect::<Vec<_>>(),
187 )
188 }
189 }
190 DataType::Float32 => {
191 let array = array.as_primitive::<Float32Type>();
192 let array = array.iter().collect::<Vec<_>>();
193 builder.encode_field(&array)
194 }
195 DataType::Float64 => {
196 let array = array.as_primitive::<Float64Type>();
197 let array = array.iter().collect::<Vec<_>>();
198 builder.encode_field(&array)
199 }
200 DataType::Binary => {
201 let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
202
203 let array = array.as_binary::<i32>();
204 match *bytea_output {
205 PGByteaOutputValue::ESCAPE => {
206 let array = array
207 .iter()
208 .map(|v| v.map(EscapeOutputBytea))
209 .collect::<Vec<_>>();
210 builder.encode_field(&array)
211 }
212 PGByteaOutputValue::HEX => {
213 let array = array
214 .iter()
215 .map(|v| v.map(HexOutputBytea))
216 .collect::<Vec<_>>();
217 builder.encode_field(&array)
218 }
219 }
220 }
221 DataType::Utf8 => {
222 let array = array.as_string::<i32>();
223 let array = array.into_iter().collect::<Vec<_>>();
224 builder.encode_field(&array)
225 }
226 DataType::LargeUtf8 => {
227 let array = array.as_string::<i64>();
228 let array = array.into_iter().collect::<Vec<_>>();
229 builder.encode_field(&array)
230 }
231 DataType::Utf8View => {
232 let array = array.as_string_view();
233 let array = array.into_iter().collect::<Vec<_>>();
234 builder.encode_field(&array)
235 }
236 DataType::Date32 | DataType::Date64 => {
237 let iter: Box<dyn Iterator<Item = Option<i32>>> =
238 if matches!(array.data_type(), DataType::Date32) {
239 let array = array.as_primitive::<Date32Type>();
240 Box::new(array.into_iter())
241 } else {
242 let array = array.as_primitive::<Date64Type>();
243 Box::new(
247 array
248 .into_iter()
249 .map(|x| x.map(|i| (i / 86_400_000) as i32)),
250 )
251 };
252 let array = iter
253 .into_iter()
254 .map(|v| match v {
255 None => Ok(None),
256 Some(v) => {
257 if let Some(date) = Date::new(v).to_chrono_date() {
258 let (style, order) =
259 *query_ctx.configuration_parameter().pg_datetime_style();
260 Ok(Some(StylingDate(date, style, order)))
261 } else {
262 Err(convert_err(Error::Internal {
263 err_msg: format!("Failed to convert date to postgres type {v:?}",),
264 }))
265 }
266 }
267 })
268 .collect::<PgWireResult<Vec<Option<StylingDate>>>>()?;
269 builder.encode_field(&array)
270 }
271 DataType::Timestamp(time_unit, _) => {
272 let array = match time_unit {
273 TimeUnit::Second => {
274 let array = array.as_primitive::<TimestampSecondType>();
275 array.into_iter().collect::<Vec<_>>()
276 }
277 TimeUnit::Millisecond => {
278 let array = array.as_primitive::<TimestampMillisecondType>();
279 array.into_iter().collect::<Vec<_>>()
280 }
281 TimeUnit::Microsecond => {
282 let array = array.as_primitive::<TimestampMicrosecondType>();
283 array.into_iter().collect::<Vec<_>>()
284 }
285 TimeUnit::Nanosecond => {
286 let array = array.as_primitive::<TimestampNanosecondType>();
287 array.into_iter().collect::<Vec<_>>()
288 }
289 };
290 let time_unit = time_unit.into();
291 let array = array
292 .into_iter()
293 .map(|v| match v {
294 None => Ok(None),
295 Some(v) => {
296 let v = Timestamp::new(v, time_unit);
297 if let Some(datetime) =
298 v.to_chrono_datetime_with_timezone(Some(&query_ctx.timezone()))
299 {
300 let (style, order) =
301 *query_ctx.configuration_parameter().pg_datetime_style();
302 Ok(Some(StylingDateTime(datetime, style, order)))
303 } else {
304 Err(convert_err(Error::Internal {
305 err_msg: format!("Failed to convert date to postgres type {v:?}",),
306 }))
307 }
308 }
309 })
310 .collect::<PgWireResult<Vec<Option<StylingDateTime>>>>()?;
311 builder.encode_field(&array)
312 }
313 DataType::Time32(time_unit) | DataType::Time64(time_unit) => {
314 let iter: Box<dyn Iterator<Item = Option<Time>>> = match time_unit {
315 TimeUnit::Second => {
316 let array = array.as_primitive::<Time32SecondType>();
317 Box::new(
318 array
319 .into_iter()
320 .map(|v| v.map(|i| Time::new_second(i as i64))),
321 )
322 }
323 TimeUnit::Millisecond => {
324 let array = array.as_primitive::<Time32MillisecondType>();
325 Box::new(
326 array
327 .into_iter()
328 .map(|v| v.map(|i| Time::new_millisecond(i as i64))),
329 )
330 }
331 TimeUnit::Microsecond => {
332 let array = array.as_primitive::<Time64MicrosecondType>();
333 Box::new(array.into_iter().map(|v| v.map(Time::new_microsecond)))
334 }
335 TimeUnit::Nanosecond => {
336 let array = array.as_primitive::<Time64NanosecondType>();
337 Box::new(array.into_iter().map(|v| v.map(Time::new_nanosecond)))
338 }
339 };
340 let array = iter
341 .into_iter()
342 .map(|v| v.and_then(|v| v.to_chrono_time()))
343 .collect::<Vec<Option<NaiveTime>>>();
344 builder.encode_field(&array)
345 }
346 DataType::Interval(interval_unit) => {
347 let array = match interval_unit {
348 IntervalUnit::YearMonth => {
349 let array = array.as_primitive::<IntervalYearMonthType>();
350 array
351 .into_iter()
352 .map(|v| v.map(|i| PgInterval::from(IntervalYearMonth::from(i))))
353 .collect::<Vec<_>>()
354 }
355 IntervalUnit::DayTime => {
356 let array = array.as_primitive::<IntervalDayTimeType>();
357 array
358 .into_iter()
359 .map(|v| v.map(|i| PgInterval::from(IntervalDayTime::from(i))))
360 .collect::<Vec<_>>()
361 }
362 IntervalUnit::MonthDayNano => {
363 let array = array.as_primitive::<IntervalMonthDayNanoType>();
364 array
365 .into_iter()
366 .map(|v| v.map(|i| PgInterval::from(IntervalMonthDayNano::from(i))))
367 .collect::<Vec<_>>()
368 }
369 };
370 builder.encode_field(&array)
371 }
372 DataType::Decimal128(precision, scale) => {
373 let array = array.as_primitive::<Decimal128Type>();
374 let array = array
375 .into_iter()
376 .map(|v| v.map(|i| Decimal128::new(i, *precision, *scale).to_string()))
377 .collect::<Vec<_>>();
378 builder.encode_field(&array)
379 }
380 _ => Err(convert_err(Error::Internal {
381 err_msg: format!(
382 "cannot write array type {:?} in postgres protocol: unimplemented",
383 array.data_type()
384 ),
385 })),
386 }
387}
388
389pub(crate) struct RecordBatchRowIterator {
390 query_ctx: QueryContextRef,
391 pg_schema: Arc<Vec<FieldInfo>>,
392 schema: SchemaRef,
393 record_batch: arrow::record_batch::RecordBatch,
394 i: usize,
395}
396
397impl Iterator for RecordBatchRowIterator {
398 type Item = PgWireResult<DataRow>;
399
400 fn next(&mut self) -> Option<Self::Item> {
401 if self.i < self.record_batch.num_rows() {
402 let mut encoder = DataRowEncoder::new(self.pg_schema.clone());
403 if let Err(e) = self.encode_row(self.i, &mut encoder) {
404 return Some(Err(e));
405 }
406 self.i += 1;
407 Some(encoder.finish())
408 } else {
409 None
410 }
411 }
412}
413
414impl RecordBatchRowIterator {
415 pub(crate) fn new(
416 query_ctx: QueryContextRef,
417 pg_schema: Arc<Vec<FieldInfo>>,
418 record_batch: RecordBatch,
419 ) -> Self {
420 let schema = record_batch.schema.clone();
421 let record_batch = record_batch.into_df_record_batch();
422 Self {
423 query_ctx,
424 pg_schema,
425 schema,
426 record_batch,
427 i: 0,
428 }
429 }
430
431 fn encode_row(&mut self, i: usize, encoder: &mut DataRowEncoder) -> PgWireResult<()> {
432 for (j, column) in self.record_batch.columns().iter().enumerate() {
433 if column.is_null(i) {
434 encoder.encode_field(&None::<&i8>)?;
435 continue;
436 }
437
438 match column.data_type() {
439 DataType::Null => {
440 encoder.encode_field(&None::<&i8>)?;
441 }
442 DataType::Boolean => {
443 let array = column.as_boolean();
444 encoder.encode_field(&array.value(i))?;
445 }
446 DataType::UInt8 => {
447 let array = column.as_primitive::<UInt8Type>();
448 let value = array.value(i);
449 if value <= i8::MAX as u8 {
450 encoder.encode_field(&(value as i8))?;
451 } else {
452 encoder.encode_field(&(value as i16))?;
453 }
454 }
455 DataType::UInt16 => {
456 let array = column.as_primitive::<UInt16Type>();
457 let value = array.value(i);
458 if value <= i16::MAX as u16 {
459 encoder.encode_field(&(value as i16))?;
460 } else {
461 encoder.encode_field(&(value as i32))?;
462 }
463 }
464 DataType::UInt32 => {
465 let array = column.as_primitive::<UInt32Type>();
466 let value = array.value(i);
467 if value <= i32::MAX as u32 {
468 encoder.encode_field(&(value as i32))?;
469 } else {
470 encoder.encode_field(&(value as i64))?;
471 }
472 }
473 DataType::UInt64 => {
474 let array = column.as_primitive::<UInt64Type>();
475 let value = array.value(i);
476 if value <= i64::MAX as u64 {
477 encoder.encode_field(&(value as i64))?;
478 } else {
479 encoder.encode_field(&value.to_string())?;
480 }
481 }
482 DataType::Int8 => {
483 let array = column.as_primitive::<Int8Type>();
484 encoder.encode_field(&array.value(i))?;
485 }
486 DataType::Int16 => {
487 let array = column.as_primitive::<Int16Type>();
488 encoder.encode_field(&array.value(i))?;
489 }
490 DataType::Int32 => {
491 let array = column.as_primitive::<Int32Type>();
492 encoder.encode_field(&array.value(i))?;
493 }
494 DataType::Int64 => {
495 let array = column.as_primitive::<Int64Type>();
496 encoder.encode_field(&array.value(i))?;
497 }
498 DataType::Float32 => {
499 let array = column.as_primitive::<Float32Type>();
500 encoder.encode_field(&array.value(i))?;
501 }
502 DataType::Float64 => {
503 let array = column.as_primitive::<Float64Type>();
504 encoder.encode_field(&array.value(i))?;
505 }
506 DataType::Utf8 => {
507 let array = column.as_string::<i32>();
508 let value = array.value(i);
509 encoder.encode_field(&value)?;
510 }
511 DataType::Utf8View => {
512 let array = column.as_string_view();
513 let value = array.value(i);
514 encoder.encode_field(&value)?;
515 }
516 DataType::LargeUtf8 => {
517 let array = column.as_string::<i64>();
518 let value = array.value(i);
519 encoder.encode_field(&value)?;
520 }
521 DataType::Binary => {
522 let array = column.as_binary::<i32>();
523 let v = array.value(i);
524 encode_bytes(
525 &self.schema.column_schemas()[j],
526 v,
527 encoder,
528 &self.query_ctx,
529 )?;
530 }
531 DataType::BinaryView => {
532 let array = column.as_binary_view();
533 let v = array.value(i);
534 encode_bytes(
535 &self.schema.column_schemas()[j],
536 v,
537 encoder,
538 &self.query_ctx,
539 )?;
540 }
541 DataType::LargeBinary => {
542 let array = column.as_binary::<i64>();
543 let v = array.value(i);
544 encode_bytes(
545 &self.schema.column_schemas()[j],
546 v,
547 encoder,
548 &self.query_ctx,
549 )?;
550 }
551 DataType::Date32 | DataType::Date64 => {
552 let v = if matches!(column.data_type(), DataType::Date32) {
553 let array = column.as_primitive::<Date32Type>();
554 array.value(i)
555 } else {
556 let array = column.as_primitive::<Date64Type>();
557 (array.value(i) / 86_400_000) as i32
561 };
562 let v = Date::new(v);
563 let date = v.to_chrono_date().map(|v| {
564 let (style, order) =
565 *self.query_ctx.configuration_parameter().pg_datetime_style();
566 StylingDate(v, style, order)
567 });
568 encoder.encode_field(&date)?;
569 }
570 DataType::Timestamp(time_unit, _) => {
571 let v = match time_unit {
572 TimeUnit::Second => {
573 let array = column.as_primitive::<TimestampSecondType>();
574 array.value(i)
575 }
576 TimeUnit::Millisecond => {
577 let array = column.as_primitive::<TimestampMillisecondType>();
578 array.value(i)
579 }
580 TimeUnit::Microsecond => {
581 let array = column.as_primitive::<TimestampMicrosecondType>();
582 array.value(i)
583 }
584 TimeUnit::Nanosecond => {
585 let array = column.as_primitive::<TimestampNanosecondType>();
586 array.value(i)
587 }
588 };
589 let v = Timestamp::new(v, time_unit.into());
590 let datetime = v
591 .to_chrono_datetime_with_timezone(Some(&self.query_ctx.timezone()))
592 .map(|v| {
593 let (style, order) =
594 *self.query_ctx.configuration_parameter().pg_datetime_style();
595 StylingDateTime(v, style, order)
596 });
597 encoder.encode_field(&datetime)?;
598 }
599 DataType::Interval(interval_unit) => match interval_unit {
600 IntervalUnit::YearMonth => {
601 let array = column.as_primitive::<IntervalYearMonthType>();
602 let v: IntervalYearMonth = array.value(i).into();
603 encoder.encode_field(&PgInterval::from(v))?;
604 }
605 IntervalUnit::DayTime => {
606 let array = column.as_primitive::<IntervalDayTimeType>();
607 let v: IntervalDayTime = array.value(i).into();
608 encoder.encode_field(&PgInterval::from(v))?;
609 }
610 IntervalUnit::MonthDayNano => {
611 let array = column.as_primitive::<IntervalMonthDayNanoType>();
612 let v: IntervalMonthDayNano = array.value(i).into();
613 encoder.encode_field(&PgInterval::from(v))?;
614 }
615 },
616 DataType::Duration(time_unit) => {
617 let v = match time_unit {
618 TimeUnit::Second => {
619 let array = column.as_primitive::<DurationSecondType>();
620 array.value(i)
621 }
622 TimeUnit::Millisecond => {
623 let array = column.as_primitive::<DurationMillisecondType>();
624 array.value(i)
625 }
626 TimeUnit::Microsecond => {
627 let array = column.as_primitive::<DurationMicrosecondType>();
628 array.value(i)
629 }
630 TimeUnit::Nanosecond => {
631 let array = column.as_primitive::<DurationNanosecondType>();
632 array.value(i)
633 }
634 };
635 let d = Duration::new(v, time_unit.into());
636 match PgInterval::try_from(d) {
637 Ok(i) => encoder.encode_field(&i)?,
638 Err(e) => {
639 return Err(convert_err(Error::Internal {
640 err_msg: e.to_string(),
641 }));
642 }
643 }
644 }
645 DataType::List(_) => {
646 let array = column.as_list::<i32>();
647 let items = array.value(i);
648 encode_array(&self.query_ctx, items, encoder)?;
649 }
650 DataType::Struct(_) => {
651 encode_struct(&self.query_ctx, Default::default(), encoder)?;
652 }
653 DataType::Time32(time_unit) | DataType::Time64(time_unit) => {
654 let v = match time_unit {
655 TimeUnit::Second => {
656 let array = column.as_primitive::<Time32SecondType>();
657 Time::new_second(array.value(i) as i64)
658 }
659 TimeUnit::Millisecond => {
660 let array = column.as_primitive::<Time32MillisecondType>();
661 Time::new_millisecond(array.value(i) as i64)
662 }
663 TimeUnit::Microsecond => {
664 let array = column.as_primitive::<Time64MicrosecondType>();
665 Time::new_microsecond(array.value(i))
666 }
667 TimeUnit::Nanosecond => {
668 let array = column.as_primitive::<Time64NanosecondType>();
669 Time::new_nanosecond(array.value(i))
670 }
671 };
672 encoder.encode_field(&v.to_chrono_time())?;
673 }
674 DataType::Decimal128(precision, scale) => {
675 let array = column.as_primitive::<Decimal128Type>();
676 let v = Decimal128::new(array.value(i), *precision, *scale);
677 encoder.encode_field(&v.to_string())?;
678 }
679 _ => {
680 return Err(convert_err(Error::Internal {
681 err_msg: format!(
682 "cannot convert datatype {} to postgres",
683 column.data_type()
684 ),
685 }));
686 }
687 }
688 }
689 Ok(())
690 }
691}
692
693fn encode_bytes(
694 schema: &ColumnSchema,
695 v: &[u8],
696 encoder: &mut DataRowEncoder,
697 query_ctx: &QueryContextRef,
698) -> PgWireResult<()> {
699 if let ConcreteDataType::Json(_) = &schema.data_type {
700 let s = jsonb_to_string(v).map_err(convert_err)?;
701 encoder.encode_field(&s)
702 } else {
703 let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
704 match *bytea_output {
705 PGByteaOutputValue::ESCAPE => encoder.encode_field(&EscapeOutputBytea(v)),
706 PGByteaOutputValue::HEX => encoder.encode_field(&HexOutputBytea(v)),
707 }
708 }
709}
710
711pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
712 match origin {
713 &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
714 &ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
715 &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR),
716 &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2),
717 &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4),
718 &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
719 &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
720 &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
721 &ConcreteDataType::Binary(_) | &ConcreteDataType::Vector(_) => Ok(Type::BYTEA),
722 &ConcreteDataType::String(_) => Ok(Type::VARCHAR),
723 &ConcreteDataType::Date(_) => Ok(Type::DATE),
724 &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP),
725 &ConcreteDataType::Time(_) => Ok(Type::TIME),
726 &ConcreteDataType::Interval(_) => Ok(Type::INTERVAL),
727 &ConcreteDataType::Decimal128(_) => Ok(Type::NUMERIC),
728 &ConcreteDataType::Json(_) => Ok(Type::JSON),
729 ConcreteDataType::List(list) => match list.item_type() {
730 &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
731 &ConcreteDataType::Boolean(_) => Ok(Type::BOOL_ARRAY),
732 &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR_ARRAY),
733 &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2_ARRAY),
734 &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4_ARRAY),
735 &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8_ARRAY),
736 &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4_ARRAY),
737 &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8_ARRAY),
738 &ConcreteDataType::Binary(_) => Ok(Type::BYTEA_ARRAY),
739 &ConcreteDataType::String(_) => Ok(Type::VARCHAR_ARRAY),
740 &ConcreteDataType::Date(_) => Ok(Type::DATE_ARRAY),
741 &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP_ARRAY),
742 &ConcreteDataType::Time(_) => Ok(Type::TIME_ARRAY),
743 &ConcreteDataType::Interval(_) => Ok(Type::INTERVAL_ARRAY),
744 &ConcreteDataType::Decimal128(_) => Ok(Type::NUMERIC_ARRAY),
745 &ConcreteDataType::Json(_) => Ok(Type::JSON_ARRAY),
746 &ConcreteDataType::Duration(_) => Ok(Type::INTERVAL_ARRAY),
747 &ConcreteDataType::Struct(_) => Ok(Type::JSON_ARRAY),
748 &ConcreteDataType::Dictionary(_)
749 | &ConcreteDataType::Vector(_)
750 | &ConcreteDataType::List(_) => server_error::UnsupportedDataTypeSnafu {
751 data_type: origin,
752 reason: "not implemented",
753 }
754 .fail(),
755 },
756 &ConcreteDataType::Dictionary(_) => server_error::UnsupportedDataTypeSnafu {
757 data_type: origin,
758 reason: "not implemented",
759 }
760 .fail(),
761 &ConcreteDataType::Duration(_) => Ok(Type::INTERVAL),
762 &ConcreteDataType::Struct(_) => Ok(Type::JSON),
763 }
764}
765
766#[allow(dead_code)]
767pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
768 match origin {
770 &Type::BOOL => Ok(ConcreteDataType::boolean_datatype()),
771 &Type::CHAR => Ok(ConcreteDataType::int8_datatype()),
772 &Type::INT2 => Ok(ConcreteDataType::int16_datatype()),
773 &Type::INT4 => Ok(ConcreteDataType::int32_datatype()),
774 &Type::INT8 => Ok(ConcreteDataType::int64_datatype()),
775 &Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
776 &Type::TIMESTAMP => Ok(ConcreteDataType::timestamp_datatype(
777 common_time::timestamp::TimeUnit::Millisecond,
778 )),
779 &Type::DATE => Ok(ConcreteDataType::date_datatype()),
780 &Type::TIME => Ok(ConcreteDataType::timestamp_datatype(
781 common_time::timestamp::TimeUnit::Microsecond,
782 )),
783 &Type::CHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
784 ConcreteDataType::int8_datatype(),
785 ))),
786 &Type::INT2_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
787 ConcreteDataType::int16_datatype(),
788 ))),
789 &Type::INT4_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
790 ConcreteDataType::int32_datatype(),
791 ))),
792 &Type::INT8_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
793 ConcreteDataType::int64_datatype(),
794 ))),
795 &Type::VARCHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
796 ConcreteDataType::string_datatype(),
797 ))),
798 _ => server_error::InternalSnafu {
799 err_msg: format!("unimplemented datatype {origin:?}"),
800 }
801 .fail(),
802 }
803}
804
805pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWireResult<String> {
806 let param_type = portal.statement.parameter_types.get(idx).unwrap();
809 match param_type {
810 &Type::VARCHAR | &Type::TEXT => Ok(format!(
811 "'{}'",
812 portal
813 .parameter::<String>(idx, param_type)?
814 .as_deref()
815 .unwrap_or("")
816 )),
817 &Type::BOOL => Ok(portal
818 .parameter::<bool>(idx, param_type)?
819 .map(|v| v.to_string())
820 .unwrap_or_else(|| "".to_owned())),
821 &Type::INT4 => Ok(portal
822 .parameter::<i32>(idx, param_type)?
823 .map(|v| v.to_string())
824 .unwrap_or_else(|| "".to_owned())),
825 &Type::INT8 => Ok(portal
826 .parameter::<i64>(idx, param_type)?
827 .map(|v| v.to_string())
828 .unwrap_or_else(|| "".to_owned())),
829 &Type::FLOAT4 => Ok(portal
830 .parameter::<f32>(idx, param_type)?
831 .map(|v| v.to_string())
832 .unwrap_or_else(|| "".to_owned())),
833 &Type::FLOAT8 => Ok(portal
834 .parameter::<f64>(idx, param_type)?
835 .map(|v| v.to_string())
836 .unwrap_or_else(|| "".to_owned())),
837 &Type::DATE => Ok(portal
838 .parameter::<NaiveDate>(idx, param_type)?
839 .map(|v| v.format("%Y-%m-%d").to_string())
840 .unwrap_or_else(|| "".to_owned())),
841 &Type::TIMESTAMP => Ok(portal
842 .parameter::<NaiveDateTime>(idx, param_type)?
843 .map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
844 .unwrap_or_else(|| "".to_owned())),
845 &Type::INTERVAL => Ok(portal
846 .parameter::<PgInterval>(idx, param_type)?
847 .map(|v| v.to_string())
848 .unwrap_or_else(|| "".to_owned())),
849 _ => Err(invalid_parameter_error(
850 "unsupported_parameter_type",
851 Some(param_type.to_string()),
852 )),
853 }
854}
855
856pub(super) fn invalid_parameter_error(msg: &str, detail: Option<String>) -> PgWireError {
857 let mut error_info = PgErrorCode::Ec22023.to_err_info(msg.to_string());
858 error_info.detail = detail;
859 PgWireError::UserError(Box::new(error_info))
860}
861
862fn to_timestamp_scalar_value<T>(
863 data: Option<T>,
864 unit: &TimestampType,
865 ctype: &ConcreteDataType,
866) -> PgWireResult<ScalarValue>
867where
868 T: Into<i64>,
869{
870 if let Some(n) = data {
871 Value::Timestamp(unit.create_timestamp(n.into()))
872 .try_to_scalar_value(ctype)
873 .map_err(convert_err)
874 } else {
875 Ok(ScalarValue::Null)
876 }
877}
878
879pub(super) fn parameters_to_scalar_values(
880 plan: &LogicalPlan,
881 portal: &Portal<SqlPlan>,
882) -> PgWireResult<Vec<ScalarValue>> {
883 let param_count = portal.parameter_len();
884 let mut results = Vec::with_capacity(param_count);
885
886 let client_param_types = &portal.statement.parameter_types;
887 let param_types = plan
888 .get_parameter_types()
889 .context(DataFusionSnafu)
890 .map_err(convert_err)?
891 .into_iter()
892 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
893 .collect::<HashMap<_, _>>();
894
895 for idx in 0..param_count {
896 let server_type = param_types
897 .get(&format!("${}", idx + 1))
898 .and_then(|t| t.as_ref());
899
900 let client_type = if let Some(client_given_type) = client_param_types.get(idx) {
901 client_given_type.clone()
902 } else if let Some(server_provided_type) = &server_type {
903 type_gt_to_pg(server_provided_type).map_err(convert_err)?
904 } else {
905 return Err(invalid_parameter_error(
906 "unknown_parameter_type",
907 Some(format!(
908 "Cannot get parameter type information for parameter {}",
909 idx
910 )),
911 ));
912 };
913
914 let value = match &client_type {
915 &Type::VARCHAR | &Type::TEXT => {
916 let data = portal.parameter::<String>(idx, &client_type)?;
917 if let Some(server_type) = &server_type {
918 match server_type {
919 ConcreteDataType::String(t) => {
920 if t.is_large() {
921 ScalarValue::LargeUtf8(data)
922 } else {
923 ScalarValue::Utf8(data)
924 }
925 }
926 _ => {
927 return Err(invalid_parameter_error(
928 "invalid_parameter_type",
929 Some(format!("Expected: {}, found: {}", server_type, client_type)),
930 ));
931 }
932 }
933 } else {
934 ScalarValue::Utf8(data)
935 }
936 }
937 &Type::BOOL => {
938 let data = portal.parameter::<bool>(idx, &client_type)?;
939 if let Some(server_type) = &server_type {
940 match server_type {
941 ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data),
942 _ => {
943 return Err(invalid_parameter_error(
944 "invalid_parameter_type",
945 Some(format!("Expected: {}, found: {}", server_type, client_type)),
946 ));
947 }
948 }
949 } else {
950 ScalarValue::Boolean(data)
951 }
952 }
953 &Type::INT2 => {
954 let data = portal.parameter::<i16>(idx, &client_type)?;
955 if let Some(server_type) = &server_type {
956 match server_type {
957 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
958 ConcreteDataType::Int16(_) => ScalarValue::Int16(data),
959 ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
960 ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
961 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
962 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
963 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
964 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
965 ConcreteDataType::Timestamp(unit) => {
966 to_timestamp_scalar_value(data, unit, server_type)?
967 }
968 _ => {
969 return Err(invalid_parameter_error(
970 "invalid_parameter_type",
971 Some(format!("Expected: {}, found: {}", server_type, client_type)),
972 ));
973 }
974 }
975 } else {
976 ScalarValue::Int16(data)
977 }
978 }
979 &Type::INT4 => {
980 let data = portal.parameter::<i32>(idx, &client_type)?;
981 if let Some(server_type) = &server_type {
982 match server_type {
983 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
984 ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
985 ConcreteDataType::Int32(_) => ScalarValue::Int32(data),
986 ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
987 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
988 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
989 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
990 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
991 ConcreteDataType::Timestamp(unit) => {
992 to_timestamp_scalar_value(data, unit, server_type)?
993 }
994 _ => {
995 return Err(invalid_parameter_error(
996 "invalid_parameter_type",
997 Some(format!("Expected: {}, found: {}", server_type, client_type)),
998 ));
999 }
1000 }
1001 } else {
1002 ScalarValue::Int32(data)
1003 }
1004 }
1005 &Type::INT8 => {
1006 let data = portal.parameter::<i64>(idx, &client_type)?;
1007 if let Some(server_type) = &server_type {
1008 match server_type {
1009 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
1010 ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
1011 ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
1012 ConcreteDataType::Int64(_) => ScalarValue::Int64(data),
1013 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
1014 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
1015 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
1016 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
1017 ConcreteDataType::Timestamp(unit) => {
1018 to_timestamp_scalar_value(data, unit, server_type)?
1019 }
1020 _ => {
1021 return Err(invalid_parameter_error(
1022 "invalid_parameter_type",
1023 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1024 ));
1025 }
1026 }
1027 } else {
1028 ScalarValue::Int64(data)
1029 }
1030 }
1031 &Type::FLOAT4 => {
1032 let data = portal.parameter::<f32>(idx, &client_type)?;
1033 if let Some(server_type) = &server_type {
1034 match server_type {
1035 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
1036 ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
1037 ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
1038 ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
1039 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
1040 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
1041 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
1042 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
1043 ConcreteDataType::Float32(_) => ScalarValue::Float32(data),
1044 ConcreteDataType::Float64(_) => {
1045 ScalarValue::Float64(data.map(|n| n as f64))
1046 }
1047 _ => {
1048 return Err(invalid_parameter_error(
1049 "invalid_parameter_type",
1050 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1051 ));
1052 }
1053 }
1054 } else {
1055 ScalarValue::Float32(data)
1056 }
1057 }
1058 &Type::FLOAT8 => {
1059 let data = portal.parameter::<f64>(idx, &client_type)?;
1060 if let Some(server_type) = &server_type {
1061 match server_type {
1062 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
1063 ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
1064 ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
1065 ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
1066 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
1067 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
1068 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
1069 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
1070 ConcreteDataType::Float32(_) => {
1071 ScalarValue::Float32(data.map(|n| n as f32))
1072 }
1073 ConcreteDataType::Float64(_) => ScalarValue::Float64(data),
1074 _ => {
1075 return Err(invalid_parameter_error(
1076 "invalid_parameter_type",
1077 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1078 ));
1079 }
1080 }
1081 } else {
1082 ScalarValue::Float64(data)
1083 }
1084 }
1085 &Type::TIMESTAMP => {
1086 let data = portal.parameter::<NaiveDateTime>(idx, &client_type)?;
1087 if let Some(server_type) = &server_type {
1088 match server_type {
1089 ConcreteDataType::Timestamp(unit) => match *unit {
1090 TimestampType::Second(_) => ScalarValue::TimestampSecond(
1091 data.map(|ts| ts.and_utc().timestamp()),
1092 None,
1093 ),
1094 TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond(
1095 data.map(|ts| ts.and_utc().timestamp_millis()),
1096 None,
1097 ),
1098 TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond(
1099 data.map(|ts| ts.and_utc().timestamp_micros()),
1100 None,
1101 ),
1102 TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond(
1103 data.map(|ts| ts.and_utc().timestamp_micros()),
1104 None,
1105 ),
1106 },
1107 _ => {
1108 return Err(invalid_parameter_error(
1109 "invalid_parameter_type",
1110 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1111 ));
1112 }
1113 }
1114 } else {
1115 ScalarValue::TimestampMillisecond(
1116 data.map(|ts| ts.and_utc().timestamp_millis()),
1117 None,
1118 )
1119 }
1120 }
1121 &Type::DATE => {
1122 let data = portal.parameter::<NaiveDate>(idx, &client_type)?;
1123 if let Some(server_type) = &server_type {
1124 match server_type {
1125 ConcreteDataType::Date(_) => ScalarValue::Date32(
1126 data.map(|d| (d - DateTime::UNIX_EPOCH.date_naive()).num_days() as i32),
1127 ),
1128 _ => {
1129 return Err(invalid_parameter_error(
1130 "invalid_parameter_type",
1131 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1132 ));
1133 }
1134 }
1135 } else {
1136 ScalarValue::Date32(
1137 data.map(|d| (d - DateTime::UNIX_EPOCH.date_naive()).num_days() as i32),
1138 )
1139 }
1140 }
1141 &Type::INTERVAL => {
1142 let data = portal.parameter::<PgInterval>(idx, &client_type)?;
1143 if let Some(server_type) = &server_type {
1144 match server_type {
1145 ConcreteDataType::Interval(IntervalType::YearMonth(_)) => {
1146 ScalarValue::IntervalYearMonth(
1147 data.map(|i| {
1148 if i.days != 0 || i.microseconds != 0 {
1149 Err(invalid_parameter_error(
1150 "invalid_parameter_type",
1151 Some(format!(
1152 "Expected: {}, found: {}",
1153 server_type, client_type
1154 )),
1155 ))
1156 } else {
1157 Ok(IntervalYearMonth::new(i.months).to_i32())
1158 }
1159 })
1160 .transpose()?,
1161 )
1162 }
1163 ConcreteDataType::Interval(IntervalType::DayTime(_)) => {
1164 ScalarValue::IntervalDayTime(
1165 data.map(|i| {
1166 if i.months != 0 || i.microseconds % 1000 != 0 {
1167 Err(invalid_parameter_error(
1168 "invalid_parameter_type",
1169 Some(format!(
1170 "Expected: {}, found: {}",
1171 server_type, client_type
1172 )),
1173 ))
1174 } else {
1175 Ok(IntervalDayTime::new(
1176 i.days,
1177 (i.microseconds / 1000) as i32,
1178 )
1179 .into())
1180 }
1181 })
1182 .transpose()?,
1183 )
1184 }
1185 ConcreteDataType::Interval(IntervalType::MonthDayNano(_)) => {
1186 ScalarValue::IntervalMonthDayNano(
1187 data.map(|i| IntervalMonthDayNano::from(i).into()),
1188 )
1189 }
1190 _ => {
1191 return Err(invalid_parameter_error(
1192 "invalid_parameter_type",
1193 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1194 ));
1195 }
1196 }
1197 } else {
1198 ScalarValue::IntervalMonthDayNano(
1199 data.map(|i| IntervalMonthDayNano::from(i).into()),
1200 )
1201 }
1202 }
1203 &Type::BYTEA => {
1204 let data = portal.parameter::<Vec<u8>>(idx, &client_type)?;
1205 if let Some(server_type) = &server_type {
1206 match server_type {
1207 ConcreteDataType::String(t) => {
1208 let s = data.map(|d| String::from_utf8_lossy(&d).to_string());
1209 if t.is_large() {
1210 ScalarValue::LargeUtf8(s)
1211 } else {
1212 ScalarValue::Utf8(s)
1213 }
1214 }
1215 ConcreteDataType::Binary(_) => ScalarValue::Binary(data),
1216 _ => {
1217 return Err(invalid_parameter_error(
1218 "invalid_parameter_type",
1219 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1220 ));
1221 }
1222 }
1223 } else {
1224 ScalarValue::Binary(data)
1225 }
1226 }
1227 &Type::JSONB => {
1228 let data = portal.parameter::<serde_json::Value>(idx, &client_type)?;
1229 if let Some(server_type) = &server_type {
1230 match server_type {
1231 ConcreteDataType::Binary(_) => {
1232 ScalarValue::Binary(data.map(|d| d.to_string().into_bytes()))
1233 }
1234 _ => {
1235 return Err(invalid_parameter_error(
1236 "invalid_parameter_type",
1237 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1238 ));
1239 }
1240 }
1241 } else {
1242 ScalarValue::Binary(data.map(|d| d.to_string().into_bytes()))
1243 }
1244 }
1245 &Type::INT2_ARRAY => {
1246 let data = portal.parameter::<Vec<i16>>(idx, &client_type)?;
1247 if let Some(data) = data {
1248 let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1249 ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int16, true))
1250 } else {
1251 ScalarValue::Null
1252 }
1253 }
1254 &Type::INT4_ARRAY => {
1255 let data = portal.parameter::<Vec<i32>>(idx, &client_type)?;
1256 if let Some(data) = data {
1257 let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1258 ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int32, true))
1259 } else {
1260 ScalarValue::Null
1261 }
1262 }
1263 &Type::INT8_ARRAY => {
1264 let data = portal.parameter::<Vec<i64>>(idx, &client_type)?;
1265 if let Some(data) = data {
1266 let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1267 ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int64, true))
1268 } else {
1269 ScalarValue::Null
1270 }
1271 }
1272 &Type::VARCHAR_ARRAY => {
1273 let data = portal.parameter::<Vec<String>>(idx, &client_type)?;
1274 if let Some(data) = data {
1275 let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1276 ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Utf8, true))
1277 } else {
1278 ScalarValue::Null
1279 }
1280 }
1281 _ => Err(invalid_parameter_error(
1282 "unsupported_parameter_value",
1283 Some(format!("Found type: {}", client_type)),
1284 ))?,
1285 };
1286
1287 results.push(value);
1288 }
1289
1290 Ok(results)
1291}
1292
1293pub(super) fn param_types_to_pg_types(
1294 param_types: &HashMap<String, Option<ConcreteDataType>>,
1295) -> Result<Vec<Type>> {
1296 let param_count = param_types.len();
1297 let mut types = Vec::with_capacity(param_count);
1298 for i in 0..param_count {
1299 if let Some(Some(param_type)) = param_types.get(&format!("${}", i + 1)) {
1300 let pg_type = type_gt_to_pg(param_type)?;
1301 types.push(pg_type);
1302 } else {
1303 types.push(Type::UNKNOWN);
1304 }
1305 }
1306 Ok(types)
1307}
1308
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{
        Float64Builder, Int64Builder, ListBuilder, StringBuilder, TimestampSecondBuilder,
    };
    use arrow_schema::Field;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{
        BinaryVector, BooleanVector, DateVector, Float32Vector, Float64Vector, Int8Vector,
        Int16Vector, Int32Vector, Int64Vector, IntervalDayTimeVector, IntervalMonthDayNanoVector,
        IntervalYearMonthVector, ListVector, NullVector, StringVector, TimeSecondVector,
        TimestampSecondVector, UInt8Vector, UInt16Vector, UInt32Vector, UInt64Vector, VectorRef,
    };
    use pgwire::api::Type;
    use pgwire::api::results::{FieldFormat, FieldInfo};
    use session::context::QueryContextBuilder;

    use super::*;

    /// Checks that `schema_to_pg` maps each column type of a GreptimeDB schema
    /// to the expected PostgreSQL field description when the unified text
    /// format is requested.
    #[test]
    fn test_schema_convert() {
        let column_schemas = vec![
            ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true),
            ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true),
            ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true),
            ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true),
            ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true),
            ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true),
            ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true),
            ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true),
            ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true),
            ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true),
            ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true),
            ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true),
            ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
            ColumnSchema::new(
                "timestamps",
                ConcreteDataType::timestamp_millisecond_datatype(),
                true,
            ),
            ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true),
            ColumnSchema::new("times", ConcreteDataType::time_second_datatype(), true),
            ColumnSchema::new(
                "intervals",
                ConcreteDataType::interval_month_day_nano_datatype(),
                true,
            ),
        ];
        // Expected mapping. Note that 8-bit integers (signed and unsigned) are
        // exposed as CHAR, and unsigned integers share the signed PG types.
        let pg_field_info = vec![
            FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
            FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
            FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
            FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
            FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
            FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
            FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
            FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
            FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
            FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
            FieldInfo::new(
                "float32s".into(),
                None,
                None,
                Type::FLOAT4,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "float64s".into(),
                None,
                None,
                Type::FLOAT8,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "binaries".into(),
                None,
                None,
                Type::BYTEA,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "strings".into(),
                None,
                None,
                Type::VARCHAR,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "timestamps".into(),
                None,
                None,
                Type::TIMESTAMP,
                FieldFormat::Text,
            ),
            FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
            FieldInfo::new("times".into(), None, None, Type::TIME, FieldFormat::Text),
            FieldInfo::new(
                "intervals".into(),
                None,
                None,
                Type::INTERVAL,
                FieldFormat::Text,
            ),
        ];
        let schema = Schema::new(column_schemas);
        let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
        assert_eq!(fs, pg_field_info);
    }

    /// Encodes a three-row record batch covering scalar, temporal, interval
    /// and list columns with `RecordBatchRowIterator`. It checks that every row
    /// encodes successfully and carries one field per PG schema column.
    #[test]
    fn test_encode_text_format_data() {
        let schema = vec![
            FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
            FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
            FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
            FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
            FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
            FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
            FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
            FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
            FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
            FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
            FieldInfo::new(
                "float32s".into(),
                None,
                None,
                Type::FLOAT4,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "float64s".into(),
                None,
                None,
                Type::FLOAT8,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "strings".into(),
                None,
                None,
                Type::VARCHAR,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "binaries".into(),
                None,
                None,
                Type::BYTEA,
                FieldFormat::Text,
            ),
            FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
            FieldInfo::new("times".into(), None, None, Type::TIME, FieldFormat::Text),
            FieldInfo::new(
                "timestamps".into(),
                None,
                None,
                Type::TIMESTAMP,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "interval_year_month".into(),
                None,
                None,
                Type::INTERVAL,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "interval_day_time".into(),
                None,
                None,
                Type::INTERVAL,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "interval_month_day_nano".into(),
                None,
                None,
                Type::INTERVAL,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "int_list".into(),
                None,
                None,
                Type::INT8_ARRAY,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "float_list".into(),
                None,
                None,
                Type::FLOAT8_ARRAY,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "string_list".into(),
                None,
                None,
                Type::VARCHAR_ARRAY,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "timestamp_list".into(),
                None,
                None,
                Type::TIMESTAMP_ARRAY,
                FieldFormat::Text,
            ),
        ];

        // Arrow schema of the record batch. Its column order mirrors the PG
        // field list above and the vectors below.
        let arrow_schema = arrow_schema::Schema::new(vec![
            Field::new("x", DataType::Null, true),
            Field::new("x", DataType::Boolean, true),
            Field::new("x", DataType::UInt8, true),
            Field::new("x", DataType::UInt16, true),
            Field::new("x", DataType::UInt32, true),
            Field::new("x", DataType::UInt64, true),
            Field::new("x", DataType::Int8, true),
            Field::new("x", DataType::Int16, true),
            Field::new("x", DataType::Int32, true),
            Field::new("x", DataType::Int64, true),
            Field::new("x", DataType::Float32, true),
            Field::new("x", DataType::Float64, true),
            Field::new("x", DataType::Utf8, true),
            Field::new("x", DataType::Binary, true),
            Field::new("x", DataType::Date32, true),
            Field::new("x", DataType::Time32(TimeUnit::Second), true),
            Field::new("x", DataType::Timestamp(TimeUnit::Second, None), true),
            Field::new("x", DataType::Interval(IntervalUnit::YearMonth), true),
            Field::new("x", DataType::Interval(IntervalUnit::DayTime), true),
            Field::new("x", DataType::Interval(IntervalUnit::MonthDayNano), true),
            Field::new(
                "x",
                DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
                true,
            ),
            Field::new(
                "x",
                DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
                true,
            ),
            Field::new(
                "x",
                DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
                true,
            ),
            Field::new(
                "x",
                DataType::List(Arc::new(Field::new(
                    "item",
                    DataType::Timestamp(TimeUnit::Second, None),
                    true,
                ))),
                true,
            ),
        ]);

        // Each list column has a non-empty list, a null list, and another
        // non-empty list. Every non-empty list also contains a null element.
        let mut builder = ListBuilder::new(Int64Builder::new());
        builder.append_value([Some(1i64), None, Some(2)]);
        builder.append_null();
        builder.append_value([Some(-1i64), None, Some(-2)]);
        let i64_list_array = builder.finish();

        let mut builder = ListBuilder::new(Float64Builder::new());
        builder.append_value([Some(1.0f64), None, Some(2.0)]);
        builder.append_null();
        builder.append_value([Some(-1.0f64), None, Some(-2.0)]);
        let f64_list_array = builder.finish();

        let mut builder = ListBuilder::new(StringBuilder::new());
        builder.append_value([Some("a"), None, Some("b")]);
        builder.append_null();
        builder.append_value([Some("c"), None, Some("d")]);
        let string_list_array = builder.finish();

        let mut builder = ListBuilder::new(TimestampSecondBuilder::new());
        builder.append_value([Some(1i64), None, Some(2)]);
        builder.append_null();
        builder.append_value([Some(3i64), None, Some(4)]);
        let timestamp_list_array = builder.finish();

        let values = vec![
            Arc::new(NullVector::new(3)) as VectorRef,
            Arc::new(BooleanVector::from(vec![Some(true), Some(false), None])),
            Arc::new(UInt8Vector::from(vec![Some(u8::MAX), Some(u8::MIN), None])),
            Arc::new(UInt16Vector::from(vec![
                Some(u16::MAX),
                Some(u16::MIN),
                None,
            ])),
            Arc::new(UInt32Vector::from(vec![
                Some(u32::MAX),
                Some(u32::MIN),
                None,
            ])),
            Arc::new(UInt64Vector::from(vec![
                Some(u64::MAX),
                Some(u64::MIN),
                None,
            ])),
            Arc::new(Int8Vector::from(vec![Some(i8::MAX), Some(i8::MIN), None])),
            Arc::new(Int16Vector::from(vec![
                Some(i16::MAX),
                Some(i16::MIN),
                None,
            ])),
            Arc::new(Int32Vector::from(vec![
                Some(i32::MAX),
                Some(i32::MIN),
                None,
            ])),
            Arc::new(Int64Vector::from(vec![
                Some(i64::MAX),
                Some(i64::MIN),
                None,
            ])),
            Arc::new(Float32Vector::from(vec![
                None,
                Some(f32::MAX),
                Some(f32::MIN),
            ])),
            Arc::new(Float64Vector::from(vec![
                None,
                Some(f64::MAX),
                Some(f64::MIN),
            ])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
            ])),
            Arc::new(BinaryVector::from(vec![
                None,
                Some("hello".as_bytes().to_vec()),
                Some("world".as_bytes().to_vec()),
            ])),
            Arc::new(DateVector::from(vec![Some(1001), None, Some(1)])),
            Arc::new(TimeSecondVector::from(vec![Some(1001), None, Some(1)])),
            Arc::new(TimestampSecondVector::from(vec![
                Some(1000001),
                None,
                Some(1),
            ])),
            Arc::new(IntervalYearMonthVector::from(vec![Some(1), None, Some(2)])),
            Arc::new(IntervalDayTimeVector::from(vec![
                Some(arrow::datatypes::IntervalDayTime::new(1, 1)),
                None,
                Some(arrow::datatypes::IntervalDayTime::new(2, 2)),
            ])),
            Arc::new(IntervalMonthDayNanoVector::from(vec![
                Some(arrow::datatypes::IntervalMonthDayNano::new(1, 1, 10)),
                None,
                Some(arrow::datatypes::IntervalMonthDayNano::new(2, 2, 20)),
            ])),
            Arc::new(ListVector::from(i64_list_array)),
            Arc::new(ListVector::from(f64_list_array)),
            Arc::new(ListVector::from(string_list_array)),
            Arc::new(ListVector::from(timestamp_list_array)),
        ];
        let record_batch =
            RecordBatch::new(Arc::new(arrow_schema.try_into().unwrap()), values).unwrap();

        let query_context = QueryContextBuilder::default()
            .configuration_parameter(Default::default())
            .build()
            .into();
        let schema = Arc::new(schema);

        // Collect into a `Result` rather than filtering out errors. An encoding
        // failure then reports the real error, not only a row-count mismatch.
        let rows = RecordBatchRowIterator::new(query_context, schema.clone(), record_batch)
            .collect::<std::result::Result<Vec<_>, _>>()
            .expect("every row of the record batch should encode without error");
        assert_eq!(rows.len(), 3);
        for row in rows {
            assert_eq!(row.field_count, schema.len() as i16);
        }
    }

    /// Checks that `invalid_parameter_error` builds a `UserError` with
    /// severity `ERROR`, SQLSTATE `22023`, and the given message.
    #[test]
    fn test_invalid_parameter() {
        let msg = "invalid_parameter_count";
        let error = invalid_parameter_error(msg, None);
        let PgWireError::UserError(value) = error else {
            panic!("test_invalid_parameter failed");
        };
        assert_eq!("ERROR", value.severity);
        assert_eq!("22023", value.code);
        assert_eq!(msg, value.message);
    }
}