1mod bytea;
16mod datetime;
17mod error;
18mod interval;
19
20use std::collections::HashMap;
21use std::sync::Arc;
22
23use arrow::array::{Array, ArrayRef, AsArray};
24use arrow::datatypes::{
25 Date32Type, Date64Type, Decimal128Type, Float32Type, Float64Type, Int8Type, Int16Type,
26 Int32Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
27 Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
28 TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
29 TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
30};
31use arrow_schema::{DataType, IntervalUnit, TimeUnit};
32use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
33use common_decimal::Decimal128;
34use common_recordbatch::RecordBatch;
35use common_time::time::Time;
36use common_time::{Date, IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth, Timestamp};
37use datafusion_common::ScalarValue;
38use datafusion_expr::LogicalPlan;
39use datatypes::arrow::datatypes::DataType as ArrowDataType;
40use datatypes::json::JsonStructureSettings;
41use datatypes::prelude::{ConcreteDataType, Value};
42use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
43use datatypes::types::{IntervalType, TimestampType, jsonb_to_string};
44use datatypes::value::StructValue;
45use pgwire::api::Type;
46use pgwire::api::portal::{Format, Portal};
47use pgwire::api::results::{DataRowEncoder, FieldInfo};
48use pgwire::error::{PgWireError, PgWireResult};
49use pgwire::messages::data::DataRow;
50use session::context::QueryContextRef;
51use session::session_config::PGByteaOutputValue;
52use snafu::ResultExt;
53
54use self::bytea::{EscapeOutputBytea, HexOutputBytea};
55use self::datetime::{StylingDate, StylingDateTime};
56pub use self::error::{PgErrorCode, PgErrorSeverity};
57use self::interval::PgInterval;
58use crate::SqlPlan;
59use crate::error::{self as server_error, DataFusionSnafu, Error, Result};
60use crate::postgres::utils::convert_err;
61
62pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
63 origin
64 .column_schemas()
65 .iter()
66 .enumerate()
67 .map(|(idx, col)| {
68 Ok(FieldInfo::new(
69 col.name.clone(),
70 None,
71 None,
72 type_gt_to_pg(&col.data_type)?,
73 field_formats.format_for(idx),
74 ))
75 })
76 .collect::<Result<Vec<FieldInfo>>>()
77}
78
79fn encode_struct(
89 _query_ctx: &QueryContextRef,
90 struct_value: StructValue,
91 builder: &mut DataRowEncoder,
92) -> PgWireResult<()> {
93 let encoding_setting = JsonStructureSettings::Structured(None);
94 let json_value = encoding_setting
95 .decode(Value::Struct(struct_value))
96 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
97
98 builder.encode_field(&json_value)
99}
100
101fn encode_array(
102 query_ctx: &QueryContextRef,
103 array: ArrayRef,
104 builder: &mut DataRowEncoder,
105) -> PgWireResult<()> {
106 macro_rules! encode_primitive_array {
107 ($array: ident, $data_type: ty, $lower_type: ty, $upper_type: ty) => {{
108 let array = $array.iter().collect::<Vec<Option<$data_type>>>();
109 if array
110 .iter()
111 .all(|x| x.is_none_or(|i| i <= <$lower_type>::MAX as $data_type))
112 {
113 builder.encode_field(
114 &array
115 .into_iter()
116 .map(|x| x.map(|i| i as $lower_type))
117 .collect::<Vec<Option<$lower_type>>>(),
118 )
119 } else {
120 builder.encode_field(
121 &array
122 .into_iter()
123 .map(|x| x.map(|i| i as $upper_type))
124 .collect::<Vec<Option<$upper_type>>>(),
125 )
126 }
127 }};
128 }
129
130 match array.data_type() {
131 DataType::Boolean => {
132 let array = array.as_boolean();
133 let array = array.iter().collect::<Vec<_>>();
134 builder.encode_field(&array)
135 }
136 DataType::Int8 => {
137 let array = array.as_primitive::<Int8Type>();
138 let array = array.iter().collect::<Vec<_>>();
139 builder.encode_field(&array)
140 }
141 DataType::Int16 => {
142 let array = array.as_primitive::<Int16Type>();
143 let array = array.iter().collect::<Vec<_>>();
144 builder.encode_field(&array)
145 }
146 DataType::Int32 => {
147 let array = array.as_primitive::<Int32Type>();
148 let array = array.iter().collect::<Vec<_>>();
149 builder.encode_field(&array)
150 }
151 DataType::Int64 => {
152 let array = array.as_primitive::<Int64Type>();
153 let array = array.iter().collect::<Vec<_>>();
154 builder.encode_field(&array)
155 }
156 DataType::UInt8 => {
157 let array = array.as_primitive::<UInt8Type>();
158 encode_primitive_array!(array, u8, i8, i16)
159 }
160 DataType::UInt16 => {
161 let array = array.as_primitive::<UInt16Type>();
162 encode_primitive_array!(array, u16, i16, i32)
163 }
164 DataType::UInt32 => {
165 let array = array.as_primitive::<UInt32Type>();
166 encode_primitive_array!(array, u32, i32, i64)
167 }
168 DataType::UInt64 => {
169 let array = array.as_primitive::<UInt64Type>();
170 let array = array.iter().collect::<Vec<_>>();
171 if array.iter().all(|x| x.is_none_or(|i| i <= i64::MAX as u64)) {
172 builder.encode_field(
173 &array
174 .into_iter()
175 .map(|x| x.map(|i| i as i64))
176 .collect::<Vec<Option<i64>>>(),
177 )
178 } else {
179 builder.encode_field(
180 &array
181 .into_iter()
182 .map(|x| x.map(|i| i.to_string()))
183 .collect::<Vec<_>>(),
184 )
185 }
186 }
187 DataType::Float32 => {
188 let array = array.as_primitive::<Float32Type>();
189 let array = array.iter().collect::<Vec<_>>();
190 builder.encode_field(&array)
191 }
192 DataType::Float64 => {
193 let array = array.as_primitive::<Float64Type>();
194 let array = array.iter().collect::<Vec<_>>();
195 builder.encode_field(&array)
196 }
197 DataType::Binary => {
198 let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
199
200 let array = array.as_binary::<i32>();
201 match *bytea_output {
202 PGByteaOutputValue::ESCAPE => {
203 let array = array
204 .iter()
205 .map(|v| v.map(EscapeOutputBytea))
206 .collect::<Vec<_>>();
207 builder.encode_field(&array)
208 }
209 PGByteaOutputValue::HEX => {
210 let array = array
211 .iter()
212 .map(|v| v.map(HexOutputBytea))
213 .collect::<Vec<_>>();
214 builder.encode_field(&array)
215 }
216 }
217 }
218 DataType::Utf8 => {
219 let array = array.as_string::<i32>();
220 let array = array.into_iter().collect::<Vec<_>>();
221 builder.encode_field(&array)
222 }
223 DataType::LargeUtf8 => {
224 let array = array.as_string::<i64>();
225 let array = array.into_iter().collect::<Vec<_>>();
226 builder.encode_field(&array)
227 }
228 DataType::Utf8View => {
229 let array = array.as_string_view();
230 let array = array.into_iter().collect::<Vec<_>>();
231 builder.encode_field(&array)
232 }
233 DataType::Date32 | DataType::Date64 => {
234 let iter: Box<dyn Iterator<Item = Option<i32>>> =
235 if matches!(array.data_type(), DataType::Date32) {
236 let array = array.as_primitive::<Date32Type>();
237 Box::new(array.into_iter())
238 } else {
239 let array = array.as_primitive::<Date64Type>();
240 Box::new(
244 array
245 .into_iter()
246 .map(|x| x.map(|i| (i / 86_400_000) as i32)),
247 )
248 };
249 let array = iter
250 .into_iter()
251 .map(|v| match v {
252 None => Ok(None),
253 Some(v) => {
254 if let Some(date) = Date::new(v).to_chrono_date() {
255 let (style, order) =
256 *query_ctx.configuration_parameter().pg_datetime_style();
257 Ok(Some(StylingDate(date, style, order)))
258 } else {
259 Err(convert_err(Error::Internal {
260 err_msg: format!("Failed to convert date to postgres type {v:?}",),
261 }))
262 }
263 }
264 })
265 .collect::<PgWireResult<Vec<Option<StylingDate>>>>()?;
266 builder.encode_field(&array)
267 }
268 DataType::Timestamp(time_unit, _) => {
269 let array = match time_unit {
270 TimeUnit::Second => {
271 let array = array.as_primitive::<TimestampSecondType>();
272 array.into_iter().collect::<Vec<_>>()
273 }
274 TimeUnit::Millisecond => {
275 let array = array.as_primitive::<TimestampMillisecondType>();
276 array.into_iter().collect::<Vec<_>>()
277 }
278 TimeUnit::Microsecond => {
279 let array = array.as_primitive::<TimestampMicrosecondType>();
280 array.into_iter().collect::<Vec<_>>()
281 }
282 TimeUnit::Nanosecond => {
283 let array = array.as_primitive::<TimestampNanosecondType>();
284 array.into_iter().collect::<Vec<_>>()
285 }
286 };
287 let time_unit = time_unit.into();
288 let array = array
289 .into_iter()
290 .map(|v| match v {
291 None => Ok(None),
292 Some(v) => {
293 let v = Timestamp::new(v, time_unit);
294 if let Some(datetime) =
295 v.to_chrono_datetime_with_timezone(Some(&query_ctx.timezone()))
296 {
297 let (style, order) =
298 *query_ctx.configuration_parameter().pg_datetime_style();
299 Ok(Some(StylingDateTime(datetime, style, order)))
300 } else {
301 Err(convert_err(Error::Internal {
302 err_msg: format!("Failed to convert date to postgres type {v:?}",),
303 }))
304 }
305 }
306 })
307 .collect::<PgWireResult<Vec<Option<StylingDateTime>>>>()?;
308 builder.encode_field(&array)
309 }
310 DataType::Time32(time_unit) | DataType::Time64(time_unit) => {
311 let iter: Box<dyn Iterator<Item = Option<Time>>> = match time_unit {
312 TimeUnit::Second => {
313 let array = array.as_primitive::<Time32SecondType>();
314 Box::new(
315 array
316 .into_iter()
317 .map(|v| v.map(|i| Time::new_second(i as i64))),
318 )
319 }
320 TimeUnit::Millisecond => {
321 let array = array.as_primitive::<Time32MillisecondType>();
322 Box::new(
323 array
324 .into_iter()
325 .map(|v| v.map(|i| Time::new_millisecond(i as i64))),
326 )
327 }
328 TimeUnit::Microsecond => {
329 let array = array.as_primitive::<Time64MicrosecondType>();
330 Box::new(array.into_iter().map(|v| v.map(Time::new_microsecond)))
331 }
332 TimeUnit::Nanosecond => {
333 let array = array.as_primitive::<Time64NanosecondType>();
334 Box::new(array.into_iter().map(|v| v.map(Time::new_nanosecond)))
335 }
336 };
337 let array = iter
338 .into_iter()
339 .map(|v| v.and_then(|v| v.to_chrono_time()))
340 .collect::<Vec<Option<NaiveTime>>>();
341 builder.encode_field(&array)
342 }
343 DataType::Interval(interval_unit) => {
344 let array = match interval_unit {
345 IntervalUnit::YearMonth => {
346 let array = array.as_primitive::<IntervalYearMonthType>();
347 array
348 .into_iter()
349 .map(|v| v.map(|i| PgInterval::from(IntervalYearMonth::from(i))))
350 .collect::<Vec<_>>()
351 }
352 IntervalUnit::DayTime => {
353 let array = array.as_primitive::<IntervalDayTimeType>();
354 array
355 .into_iter()
356 .map(|v| v.map(|i| PgInterval::from(IntervalDayTime::from(i))))
357 .collect::<Vec<_>>()
358 }
359 IntervalUnit::MonthDayNano => {
360 let array = array.as_primitive::<IntervalMonthDayNanoType>();
361 array
362 .into_iter()
363 .map(|v| v.map(|i| PgInterval::from(IntervalMonthDayNano::from(i))))
364 .collect::<Vec<_>>()
365 }
366 };
367 builder.encode_field(&array)
368 }
369 DataType::Decimal128(precision, scale) => {
370 let array = array.as_primitive::<Decimal128Type>();
371 let array = array
372 .into_iter()
373 .map(|v| v.map(|i| Decimal128::new(i, *precision, *scale).to_string()))
374 .collect::<Vec<_>>();
375 builder.encode_field(&array)
376 }
377 _ => Err(convert_err(Error::Internal {
378 err_msg: format!(
379 "cannot write array type {:?} in postgres protocol: unimplemented",
380 array.data_type()
381 ),
382 })),
383 }
384}
385
386pub(crate) struct RecordBatchRowIterator {
387 query_ctx: QueryContextRef,
388 pg_schema: Arc<Vec<FieldInfo>>,
389 schema: SchemaRef,
390 record_batch: arrow::record_batch::RecordBatch,
391 i: usize,
392}
393
394impl Iterator for RecordBatchRowIterator {
395 type Item = PgWireResult<DataRow>;
396
397 fn next(&mut self) -> Option<Self::Item> {
398 if self.i < self.record_batch.num_rows() {
399 let mut encoder = DataRowEncoder::new(self.pg_schema.clone());
400 if let Err(e) = self.encode_row(self.i, &mut encoder) {
401 return Some(Err(e));
402 }
403 self.i += 1;
404 Some(encoder.finish())
405 } else {
406 None
407 }
408 }
409}
410
411impl RecordBatchRowIterator {
412 pub(crate) fn new(
413 query_ctx: QueryContextRef,
414 pg_schema: Arc<Vec<FieldInfo>>,
415 record_batch: RecordBatch,
416 ) -> Self {
417 let schema = record_batch.schema.clone();
418 let record_batch = record_batch.into_df_record_batch();
419 Self {
420 query_ctx,
421 pg_schema,
422 schema,
423 record_batch,
424 i: 0,
425 }
426 }
427
428 fn encode_row(&mut self, i: usize, encoder: &mut DataRowEncoder) -> PgWireResult<()> {
429 for (j, column) in self.record_batch.columns().iter().enumerate() {
430 if column.is_null(i) {
431 encoder.encode_field(&None::<&i8>)?;
432 continue;
433 }
434
435 match column.data_type() {
436 DataType::Null => {
437 encoder.encode_field(&None::<&i8>)?;
438 }
439 DataType::Boolean => {
440 let array = column.as_boolean();
441 encoder.encode_field(&array.value(i))?;
442 }
443 DataType::UInt8 => {
444 let array = column.as_primitive::<UInt8Type>();
445 let value = array.value(i);
446 if value <= i8::MAX as u8 {
447 encoder.encode_field(&(value as i8))?;
448 } else {
449 encoder.encode_field(&(value as i16))?;
450 }
451 }
452 DataType::UInt16 => {
453 let array = column.as_primitive::<UInt16Type>();
454 let value = array.value(i);
455 if value <= i16::MAX as u16 {
456 encoder.encode_field(&(value as i16))?;
457 } else {
458 encoder.encode_field(&(value as i32))?;
459 }
460 }
461 DataType::UInt32 => {
462 let array = column.as_primitive::<UInt32Type>();
463 let value = array.value(i);
464 if value <= i32::MAX as u32 {
465 encoder.encode_field(&(value as i32))?;
466 } else {
467 encoder.encode_field(&(value as i64))?;
468 }
469 }
470 DataType::UInt64 => {
471 let array = column.as_primitive::<UInt64Type>();
472 let value = array.value(i);
473 if value <= i64::MAX as u64 {
474 encoder.encode_field(&(value as i64))?;
475 } else {
476 encoder.encode_field(&value.to_string())?;
477 }
478 }
479 DataType::Int8 => {
480 let array = column.as_primitive::<Int8Type>();
481 encoder.encode_field(&array.value(i))?;
482 }
483 DataType::Int16 => {
484 let array = column.as_primitive::<Int16Type>();
485 encoder.encode_field(&array.value(i))?;
486 }
487 DataType::Int32 => {
488 let array = column.as_primitive::<Int32Type>();
489 encoder.encode_field(&array.value(i))?;
490 }
491 DataType::Int64 => {
492 let array = column.as_primitive::<Int64Type>();
493 encoder.encode_field(&array.value(i))?;
494 }
495 DataType::Float32 => {
496 let array = column.as_primitive::<Float32Type>();
497 encoder.encode_field(&array.value(i))?;
498 }
499 DataType::Float64 => {
500 let array = column.as_primitive::<Float64Type>();
501 encoder.encode_field(&array.value(i))?;
502 }
503 DataType::Utf8 => {
504 let array = column.as_string::<i32>();
505 let value = array.value(i);
506 encoder.encode_field(&value)?;
507 }
508 DataType::Utf8View => {
509 let array = column.as_string_view();
510 let value = array.value(i);
511 encoder.encode_field(&value)?;
512 }
513 DataType::LargeUtf8 => {
514 let array = column.as_string::<i64>();
515 let value = array.value(i);
516 encoder.encode_field(&value)?;
517 }
518 DataType::Binary => {
519 let array = column.as_binary::<i32>();
520 let v = array.value(i);
521 encode_bytes(
522 &self.schema.column_schemas()[j],
523 v,
524 encoder,
525 &self.query_ctx,
526 )?;
527 }
528 DataType::BinaryView => {
529 let array = column.as_binary_view();
530 let v = array.value(i);
531 encode_bytes(
532 &self.schema.column_schemas()[j],
533 v,
534 encoder,
535 &self.query_ctx,
536 )?;
537 }
538 DataType::LargeBinary => {
539 let array = column.as_binary::<i64>();
540 let v = array.value(i);
541 encode_bytes(
542 &self.schema.column_schemas()[j],
543 v,
544 encoder,
545 &self.query_ctx,
546 )?;
547 }
548 DataType::Date32 | DataType::Date64 => {
549 let v = if matches!(column.data_type(), DataType::Date32) {
550 let array = column.as_primitive::<Date32Type>();
551 array.value(i)
552 } else {
553 let array = column.as_primitive::<Date64Type>();
554 (array.value(i) / 86_400_000) as i32
558 };
559 let v = Date::new(v);
560 let date = v.to_chrono_date().map(|v| {
561 let (style, order) =
562 *self.query_ctx.configuration_parameter().pg_datetime_style();
563 StylingDate(v, style, order)
564 });
565 encoder.encode_field(&date)?;
566 }
567 DataType::Timestamp(_, _) => {
568 let v = datatypes::arrow_array::timestamp_array_value(column, i);
569 let datetime = v
570 .to_chrono_datetime_with_timezone(Some(&self.query_ctx.timezone()))
571 .map(|v| {
572 let (style, order) =
573 *self.query_ctx.configuration_parameter().pg_datetime_style();
574 StylingDateTime(v, style, order)
575 });
576 encoder.encode_field(&datetime)?;
577 }
578 DataType::Interval(interval_unit) => match interval_unit {
579 IntervalUnit::YearMonth => {
580 let array = column.as_primitive::<IntervalYearMonthType>();
581 let v: IntervalYearMonth = array.value(i).into();
582 encoder.encode_field(&PgInterval::from(v))?;
583 }
584 IntervalUnit::DayTime => {
585 let array = column.as_primitive::<IntervalDayTimeType>();
586 let v: IntervalDayTime = array.value(i).into();
587 encoder.encode_field(&PgInterval::from(v))?;
588 }
589 IntervalUnit::MonthDayNano => {
590 let array = column.as_primitive::<IntervalMonthDayNanoType>();
591 let v: IntervalMonthDayNano = array.value(i).into();
592 encoder.encode_field(&PgInterval::from(v))?;
593 }
594 },
595 DataType::Duration(_) => {
596 let d = datatypes::arrow_array::duration_array_value(column, i);
597 match PgInterval::try_from(d) {
598 Ok(i) => encoder.encode_field(&i)?,
599 Err(e) => {
600 return Err(convert_err(Error::Internal {
601 err_msg: e.to_string(),
602 }));
603 }
604 }
605 }
606 DataType::List(_) => {
607 let array = column.as_list::<i32>();
608 let items = array.value(i);
609 encode_array(&self.query_ctx, items, encoder)?;
610 }
611 DataType::Struct(_) => {
612 encode_struct(&self.query_ctx, Default::default(), encoder)?;
613 }
614 DataType::Time32(_) | DataType::Time64(_) => {
615 let v = datatypes::arrow_array::time_array_value(column, i);
616 encoder.encode_field(&v.to_chrono_time())?;
617 }
618 DataType::Decimal128(precision, scale) => {
619 let array = column.as_primitive::<Decimal128Type>();
620 let v = Decimal128::new(array.value(i), *precision, *scale);
621 encoder.encode_field(&v.to_string())?;
622 }
623 _ => {
624 return Err(convert_err(Error::Internal {
625 err_msg: format!(
626 "cannot convert datatype {} to postgres",
627 column.data_type()
628 ),
629 }));
630 }
631 }
632 }
633 Ok(())
634 }
635}
636
637fn encode_bytes(
638 schema: &ColumnSchema,
639 v: &[u8],
640 encoder: &mut DataRowEncoder,
641 query_ctx: &QueryContextRef,
642) -> PgWireResult<()> {
643 if let ConcreteDataType::Json(_) = &schema.data_type {
644 let s = jsonb_to_string(v).map_err(convert_err)?;
645 encoder.encode_field(&s)
646 } else {
647 let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
648 match *bytea_output {
649 PGByteaOutputValue::ESCAPE => encoder.encode_field(&EscapeOutputBytea(v)),
650 PGByteaOutputValue::HEX => encoder.encode_field(&HexOutputBytea(v)),
651 }
652 }
653}
654
655pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
656 match origin {
657 &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
658 &ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
659 &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR),
660 &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2),
661 &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4),
662 &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
663 &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
664 &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
665 &ConcreteDataType::Binary(_) | &ConcreteDataType::Vector(_) => Ok(Type::BYTEA),
666 &ConcreteDataType::String(_) => Ok(Type::VARCHAR),
667 &ConcreteDataType::Date(_) => Ok(Type::DATE),
668 &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP),
669 &ConcreteDataType::Time(_) => Ok(Type::TIME),
670 &ConcreteDataType::Interval(_) => Ok(Type::INTERVAL),
671 &ConcreteDataType::Decimal128(_) => Ok(Type::NUMERIC),
672 &ConcreteDataType::Json(_) => Ok(Type::JSON),
673 ConcreteDataType::List(list) => match list.item_type() {
674 &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
675 &ConcreteDataType::Boolean(_) => Ok(Type::BOOL_ARRAY),
676 &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR_ARRAY),
677 &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2_ARRAY),
678 &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4_ARRAY),
679 &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8_ARRAY),
680 &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4_ARRAY),
681 &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8_ARRAY),
682 &ConcreteDataType::Binary(_) => Ok(Type::BYTEA_ARRAY),
683 &ConcreteDataType::String(_) => Ok(Type::VARCHAR_ARRAY),
684 &ConcreteDataType::Date(_) => Ok(Type::DATE_ARRAY),
685 &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP_ARRAY),
686 &ConcreteDataType::Time(_) => Ok(Type::TIME_ARRAY),
687 &ConcreteDataType::Interval(_) => Ok(Type::INTERVAL_ARRAY),
688 &ConcreteDataType::Decimal128(_) => Ok(Type::NUMERIC_ARRAY),
689 &ConcreteDataType::Json(_) => Ok(Type::JSON_ARRAY),
690 &ConcreteDataType::Duration(_) => Ok(Type::INTERVAL_ARRAY),
691 &ConcreteDataType::Struct(_) => Ok(Type::JSON_ARRAY),
692 &ConcreteDataType::Dictionary(_)
693 | &ConcreteDataType::Vector(_)
694 | &ConcreteDataType::List(_) => server_error::UnsupportedDataTypeSnafu {
695 data_type: origin,
696 reason: "not implemented",
697 }
698 .fail(),
699 },
700 &ConcreteDataType::Dictionary(_) => server_error::UnsupportedDataTypeSnafu {
701 data_type: origin,
702 reason: "not implemented",
703 }
704 .fail(),
705 &ConcreteDataType::Duration(_) => Ok(Type::INTERVAL),
706 &ConcreteDataType::Struct(_) => Ok(Type::JSON),
707 }
708}
709
710#[allow(dead_code)]
711pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
712 match origin {
714 &Type::BOOL => Ok(ConcreteDataType::boolean_datatype()),
715 &Type::CHAR => Ok(ConcreteDataType::int8_datatype()),
716 &Type::INT2 => Ok(ConcreteDataType::int16_datatype()),
717 &Type::INT4 => Ok(ConcreteDataType::int32_datatype()),
718 &Type::INT8 => Ok(ConcreteDataType::int64_datatype()),
719 &Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
720 &Type::TIMESTAMP | &Type::TIMESTAMPTZ => Ok(ConcreteDataType::timestamp_datatype(
721 common_time::timestamp::TimeUnit::Millisecond,
722 )),
723 &Type::DATE => Ok(ConcreteDataType::date_datatype()),
724 &Type::TIME => Ok(ConcreteDataType::timestamp_datatype(
725 common_time::timestamp::TimeUnit::Microsecond,
726 )),
727 &Type::CHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
728 ConcreteDataType::int8_datatype(),
729 ))),
730 &Type::INT2_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
731 ConcreteDataType::int16_datatype(),
732 ))),
733 &Type::INT4_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
734 ConcreteDataType::int32_datatype(),
735 ))),
736 &Type::INT8_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
737 ConcreteDataType::int64_datatype(),
738 ))),
739 &Type::VARCHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
740 ConcreteDataType::string_datatype(),
741 ))),
742 _ => server_error::InternalSnafu {
743 err_msg: format!("unimplemented datatype {origin:?}"),
744 }
745 .fail(),
746 }
747}
748
749pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWireResult<String> {
750 let param_type = portal
753 .statement
754 .parameter_types
755 .get(idx)
756 .unwrap()
757 .as_ref()
758 .unwrap_or(&Type::UNKNOWN);
759 match param_type {
760 &Type::VARCHAR | &Type::TEXT => Ok(format!(
761 "'{}'",
762 portal
763 .parameter::<String>(idx, param_type)?
764 .as_deref()
765 .unwrap_or("")
766 )),
767 &Type::BOOL => Ok(portal
768 .parameter::<bool>(idx, param_type)?
769 .map(|v| v.to_string())
770 .unwrap_or_else(|| "".to_owned())),
771 &Type::INT4 => Ok(portal
772 .parameter::<i32>(idx, param_type)?
773 .map(|v| v.to_string())
774 .unwrap_or_else(|| "".to_owned())),
775 &Type::INT8 => Ok(portal
776 .parameter::<i64>(idx, param_type)?
777 .map(|v| v.to_string())
778 .unwrap_or_else(|| "".to_owned())),
779 &Type::FLOAT4 => Ok(portal
780 .parameter::<f32>(idx, param_type)?
781 .map(|v| v.to_string())
782 .unwrap_or_else(|| "".to_owned())),
783 &Type::FLOAT8 => Ok(portal
784 .parameter::<f64>(idx, param_type)?
785 .map(|v| v.to_string())
786 .unwrap_or_else(|| "".to_owned())),
787 &Type::DATE => Ok(portal
788 .parameter::<NaiveDate>(idx, param_type)?
789 .map(|v| v.format("%Y-%m-%d").to_string())
790 .unwrap_or_else(|| "".to_owned())),
791 &Type::TIMESTAMP => Ok(portal
792 .parameter::<NaiveDateTime>(idx, param_type)?
793 .map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
794 .unwrap_or_else(|| "".to_owned())),
795 &Type::INTERVAL => Ok(portal
796 .parameter::<PgInterval>(idx, param_type)?
797 .map(|v| v.to_string())
798 .unwrap_or_else(|| "".to_owned())),
799 _ => Err(invalid_parameter_error(
800 "unsupported_parameter_type",
801 Some(param_type.to_string()),
802 )),
803 }
804}
805
806pub(super) fn invalid_parameter_error(msg: &str, detail: Option<String>) -> PgWireError {
807 let mut error_info = PgErrorCode::Ec22023.to_err_info(msg.to_string());
808 error_info.detail = detail;
809 PgWireError::UserError(Box::new(error_info))
810}
811
812fn to_timestamp_scalar_value<T>(
813 data: Option<T>,
814 unit: &TimestampType,
815 ctype: &ConcreteDataType,
816) -> PgWireResult<ScalarValue>
817where
818 T: Into<i64>,
819{
820 if let Some(n) = data {
821 Value::Timestamp(unit.create_timestamp(n.into()))
822 .try_to_scalar_value(ctype)
823 .map_err(convert_err)
824 } else {
825 Ok(ScalarValue::Null)
826 }
827}
828
829pub(super) fn parameters_to_scalar_values(
830 plan: &LogicalPlan,
831 portal: &Portal<SqlPlan>,
832) -> PgWireResult<Vec<ScalarValue>> {
833 let param_count = portal.parameter_len();
834 let mut results = Vec::with_capacity(param_count);
835
836 let client_param_types = &portal.statement.parameter_types;
837 let server_param_types = plan
838 .get_parameter_types()
839 .context(DataFusionSnafu)
840 .map_err(convert_err)?
841 .into_iter()
842 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
843 .collect::<HashMap<_, _>>();
844
845 for idx in 0..param_count {
846 let server_type = server_param_types
847 .get(&format!("${}", idx + 1))
848 .and_then(|t| t.as_ref());
849
850 let client_type = if let Some(Some(client_given_type)) = client_param_types.get(idx) {
851 client_given_type.clone()
852 } else if let Some(server_provided_type) = &server_type {
853 type_gt_to_pg(server_provided_type).map_err(convert_err)?
854 } else {
855 return Err(invalid_parameter_error(
856 "unknown_parameter_type",
857 Some(format!(
858 "Cannot get parameter type information for parameter {}",
859 idx
860 )),
861 ));
862 };
863
864 let value = match &client_type {
865 &Type::VARCHAR | &Type::TEXT => {
866 let data = portal.parameter::<String>(idx, &client_type)?;
867 if let Some(server_type) = &server_type {
868 match server_type {
869 ConcreteDataType::String(t) => {
870 if t.is_large() {
871 ScalarValue::LargeUtf8(data)
872 } else {
873 ScalarValue::Utf8(data)
874 }
875 }
876 _ => {
877 return Err(invalid_parameter_error(
878 "invalid_parameter_type",
879 Some(format!("Expected: {}, found: {}", server_type, client_type)),
880 ));
881 }
882 }
883 } else {
884 ScalarValue::Utf8(data)
885 }
886 }
887 &Type::BOOL => {
888 let data = portal.parameter::<bool>(idx, &client_type)?;
889 if let Some(server_type) = &server_type {
890 match server_type {
891 ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data),
892 _ => {
893 return Err(invalid_parameter_error(
894 "invalid_parameter_type",
895 Some(format!("Expected: {}, found: {}", server_type, client_type)),
896 ));
897 }
898 }
899 } else {
900 ScalarValue::Boolean(data)
901 }
902 }
903 &Type::INT2 => {
904 let data = portal.parameter::<i16>(idx, &client_type)?;
905 if let Some(server_type) = &server_type {
906 match server_type {
907 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
908 ConcreteDataType::Int16(_) => ScalarValue::Int16(data),
909 ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
910 ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
911 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
912 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
913 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
914 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
915 ConcreteDataType::Timestamp(unit) => {
916 to_timestamp_scalar_value(data, unit, server_type)?
917 }
918 _ => {
919 return Err(invalid_parameter_error(
920 "invalid_parameter_type",
921 Some(format!("Expected: {}, found: {}", server_type, client_type)),
922 ));
923 }
924 }
925 } else {
926 ScalarValue::Int16(data)
927 }
928 }
929 &Type::INT4 => {
930 let data = portal.parameter::<i32>(idx, &client_type)?;
931 if let Some(server_type) = &server_type {
932 match server_type {
933 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
934 ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
935 ConcreteDataType::Int32(_) => ScalarValue::Int32(data),
936 ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
937 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
938 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
939 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
940 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
941 ConcreteDataType::Timestamp(unit) => {
942 to_timestamp_scalar_value(data, unit, server_type)?
943 }
944 _ => {
945 return Err(invalid_parameter_error(
946 "invalid_parameter_type",
947 Some(format!("Expected: {}, found: {}", server_type, client_type)),
948 ));
949 }
950 }
951 } else {
952 ScalarValue::Int32(data)
953 }
954 }
955 &Type::INT8 => {
956 let data = portal.parameter::<i64>(idx, &client_type)?;
957 if let Some(server_type) = &server_type {
958 match server_type {
959 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
960 ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
961 ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
962 ConcreteDataType::Int64(_) => ScalarValue::Int64(data),
963 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
964 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
965 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
966 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
967 ConcreteDataType::Timestamp(unit) => {
968 to_timestamp_scalar_value(data, unit, server_type)?
969 }
970 _ => {
971 return Err(invalid_parameter_error(
972 "invalid_parameter_type",
973 Some(format!("Expected: {}, found: {}", server_type, client_type)),
974 ));
975 }
976 }
977 } else {
978 ScalarValue::Int64(data)
979 }
980 }
981 &Type::FLOAT4 => {
982 let data = portal.parameter::<f32>(idx, &client_type)?;
983 if let Some(server_type) = &server_type {
984 match server_type {
985 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
986 ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
987 ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
988 ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
989 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
990 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
991 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
992 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
993 ConcreteDataType::Float32(_) => ScalarValue::Float32(data),
994 ConcreteDataType::Float64(_) => {
995 ScalarValue::Float64(data.map(|n| n as f64))
996 }
997 _ => {
998 return Err(invalid_parameter_error(
999 "invalid_parameter_type",
1000 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1001 ));
1002 }
1003 }
1004 } else {
1005 ScalarValue::Float32(data)
1006 }
1007 }
1008 &Type::FLOAT8 => {
1009 let data = portal.parameter::<f64>(idx, &client_type)?;
1010 if let Some(server_type) = &server_type {
1011 match server_type {
1012 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
1013 ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
1014 ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
1015 ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
1016 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
1017 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
1018 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
1019 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
1020 ConcreteDataType::Float32(_) => {
1021 ScalarValue::Float32(data.map(|n| n as f32))
1022 }
1023 ConcreteDataType::Float64(_) => ScalarValue::Float64(data),
1024 _ => {
1025 return Err(invalid_parameter_error(
1026 "invalid_parameter_type",
1027 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1028 ));
1029 }
1030 }
1031 } else {
1032 ScalarValue::Float64(data)
1033 }
1034 }
1035 &Type::TIMESTAMP => {
1036 let data = portal.parameter::<NaiveDateTime>(idx, &client_type)?;
1037 if let Some(server_type) = &server_type {
1038 match server_type {
1039 ConcreteDataType::Timestamp(unit) => match *unit {
1040 TimestampType::Second(_) => ScalarValue::TimestampSecond(
1041 data.map(|ts| ts.and_utc().timestamp()),
1042 None,
1043 ),
1044 TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond(
1045 data.map(|ts| ts.and_utc().timestamp_millis()),
1046 None,
1047 ),
1048 TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond(
1049 data.map(|ts| ts.and_utc().timestamp_micros()),
1050 None,
1051 ),
1052 TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond(
1053 data.and_then(|ts| ts.and_utc().timestamp_nanos_opt()),
1054 None,
1055 ),
1056 },
1057 _ => {
1058 return Err(invalid_parameter_error(
1059 "invalid_parameter_type",
1060 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1061 ));
1062 }
1063 }
1064 } else {
1065 ScalarValue::TimestampMillisecond(
1066 data.map(|ts| ts.and_utc().timestamp_millis()),
1067 None,
1068 )
1069 }
1070 }
1071 &Type::TIMESTAMPTZ => {
1072 let data = portal.parameter::<DateTime<FixedOffset>>(idx, &client_type)?;
1073 if let Some(server_type) = &server_type {
1074 match server_type {
1075 ConcreteDataType::Timestamp(unit) => match *unit {
1076 TimestampType::Second(_) => {
1077 ScalarValue::TimestampSecond(data.map(|ts| ts.timestamp()), None)
1078 }
1079 TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond(
1080 data.map(|ts| ts.timestamp_millis()),
1081 None,
1082 ),
1083 TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond(
1084 data.map(|ts| ts.timestamp_micros()),
1085 None,
1086 ),
1087 TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond(
1088 data.and_then(|ts| ts.timestamp_nanos_opt()),
1089 None,
1090 ),
1091 },
1092 _ => {
1093 return Err(invalid_parameter_error(
1094 "invalid_parameter_type",
1095 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1096 ));
1097 }
1098 }
1099 } else {
1100 ScalarValue::TimestampMillisecond(data.map(|ts| ts.timestamp_millis()), None)
1101 }
1102 }
1103 &Type::DATE => {
1104 let data = portal.parameter::<NaiveDate>(idx, &client_type)?;
1105 if let Some(server_type) = &server_type {
1106 match server_type {
1107 ConcreteDataType::Date(_) => ScalarValue::Date32(
1108 data.map(|d| (d - DateTime::UNIX_EPOCH.date_naive()).num_days() as i32),
1109 ),
1110 _ => {
1111 return Err(invalid_parameter_error(
1112 "invalid_parameter_type",
1113 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1114 ));
1115 }
1116 }
1117 } else {
1118 ScalarValue::Date32(
1119 data.map(|d| (d - DateTime::UNIX_EPOCH.date_naive()).num_days() as i32),
1120 )
1121 }
1122 }
1123 &Type::INTERVAL => {
1124 let data = portal.parameter::<PgInterval>(idx, &client_type)?;
1125 if let Some(server_type) = &server_type {
1126 match server_type {
1127 ConcreteDataType::Interval(IntervalType::YearMonth(_)) => {
1128 ScalarValue::IntervalYearMonth(
1129 data.map(|i| {
1130 if i.days != 0 || i.microseconds != 0 {
1131 Err(invalid_parameter_error(
1132 "invalid_parameter_type",
1133 Some(format!(
1134 "Expected: {}, found: {}",
1135 server_type, client_type
1136 )),
1137 ))
1138 } else {
1139 Ok(IntervalYearMonth::new(i.months).to_i32())
1140 }
1141 })
1142 .transpose()?,
1143 )
1144 }
1145 ConcreteDataType::Interval(IntervalType::DayTime(_)) => {
1146 ScalarValue::IntervalDayTime(
1147 data.map(|i| {
1148 if i.months != 0 || i.microseconds % 1000 != 0 {
1149 Err(invalid_parameter_error(
1150 "invalid_parameter_type",
1151 Some(format!(
1152 "Expected: {}, found: {}",
1153 server_type, client_type
1154 )),
1155 ))
1156 } else {
1157 Ok(IntervalDayTime::new(
1158 i.days,
1159 (i.microseconds / 1000) as i32,
1160 )
1161 .into())
1162 }
1163 })
1164 .transpose()?,
1165 )
1166 }
1167 ConcreteDataType::Interval(IntervalType::MonthDayNano(_)) => {
1168 ScalarValue::IntervalMonthDayNano(
1169 data.map(|i| IntervalMonthDayNano::from(i).into()),
1170 )
1171 }
1172 _ => {
1173 return Err(invalid_parameter_error(
1174 "invalid_parameter_type",
1175 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1176 ));
1177 }
1178 }
1179 } else {
1180 ScalarValue::IntervalMonthDayNano(
1181 data.map(|i| IntervalMonthDayNano::from(i).into()),
1182 )
1183 }
1184 }
1185 &Type::BYTEA => {
1186 let data = portal.parameter::<Vec<u8>>(idx, &client_type)?;
1187 if let Some(server_type) = &server_type {
1188 match server_type {
1189 ConcreteDataType::String(t) => {
1190 let s = data.map(|d| String::from_utf8_lossy(&d).to_string());
1191 if t.is_large() {
1192 ScalarValue::LargeUtf8(s)
1193 } else {
1194 ScalarValue::Utf8(s)
1195 }
1196 }
1197 ConcreteDataType::Binary(_) => ScalarValue::Binary(data),
1198 _ => {
1199 return Err(invalid_parameter_error(
1200 "invalid_parameter_type",
1201 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1202 ));
1203 }
1204 }
1205 } else {
1206 ScalarValue::Binary(data)
1207 }
1208 }
1209 &Type::JSONB => {
1210 let data = portal.parameter::<serde_json::Value>(idx, &client_type)?;
1211 if let Some(server_type) = &server_type {
1212 match server_type {
1213 ConcreteDataType::Binary(_) => {
1214 ScalarValue::Binary(data.map(|d| d.to_string().into_bytes()))
1215 }
1216 _ => {
1217 return Err(invalid_parameter_error(
1218 "invalid_parameter_type",
1219 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1220 ));
1221 }
1222 }
1223 } else {
1224 ScalarValue::Binary(data.map(|d| d.to_string().into_bytes()))
1225 }
1226 }
1227 &Type::INT2_ARRAY => {
1228 let data = portal.parameter::<Vec<i16>>(idx, &client_type)?;
1229 if let Some(data) = data {
1230 let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1231 ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int16, true))
1232 } else {
1233 ScalarValue::Null
1234 }
1235 }
1236 &Type::INT4_ARRAY => {
1237 let data = portal.parameter::<Vec<i32>>(idx, &client_type)?;
1238 if let Some(data) = data {
1239 let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1240 ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int32, true))
1241 } else {
1242 ScalarValue::Null
1243 }
1244 }
1245 &Type::INT8_ARRAY => {
1246 let data = portal.parameter::<Vec<i64>>(idx, &client_type)?;
1247 if let Some(data) = data {
1248 let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1249 ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int64, true))
1250 } else {
1251 ScalarValue::Null
1252 }
1253 }
1254 &Type::VARCHAR_ARRAY => {
1255 let data = portal.parameter::<Vec<String>>(idx, &client_type)?;
1256 if let Some(data) = data {
1257 let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1258 ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Utf8, true))
1259 } else {
1260 ScalarValue::Null
1261 }
1262 }
1263 _ => Err(invalid_parameter_error(
1264 "unsupported_parameter_value",
1265 Some(format!("Found type: {}", client_type)),
1266 ))?,
1267 };
1268
1269 results.push(value);
1270 }
1271
1272 Ok(results)
1273}
1274
1275pub(super) fn param_types_to_pg_types(
1276 param_types: &HashMap<String, Option<ConcreteDataType>>,
1277) -> Result<Vec<Type>> {
1278 let param_count = param_types.len();
1279 let mut types = Vec::with_capacity(param_count);
1280 for i in 0..param_count {
1281 if let Some(Some(param_type)) = param_types.get(&format!("${}", i + 1)) {
1282 let pg_type = type_gt_to_pg(param_type)?;
1283 types.push(pg_type);
1284 } else {
1285 types.push(Type::UNKNOWN);
1286 }
1287 }
1288 Ok(types)
1289}
1290
1291#[cfg(test)]
1292mod test {
1293 use std::sync::Arc;
1294
1295 use arrow::array::{
1296 Float64Builder, Int64Builder, ListBuilder, StringBuilder, TimestampSecondBuilder,
1297 };
1298 use arrow_schema::Field;
1299 use datatypes::schema::{ColumnSchema, Schema};
1300 use datatypes::vectors::{
1301 BinaryVector, BooleanVector, DateVector, Float32Vector, Float64Vector, Int8Vector,
1302 Int16Vector, Int32Vector, Int64Vector, IntervalDayTimeVector, IntervalMonthDayNanoVector,
1303 IntervalYearMonthVector, ListVector, NullVector, StringVector, TimeSecondVector,
1304 TimestampSecondVector, UInt8Vector, UInt16Vector, UInt32Vector, UInt64Vector, VectorRef,
1305 };
1306 use pgwire::api::Type;
1307 use pgwire::api::results::{FieldFormat, FieldInfo};
1308 use session::context::QueryContextBuilder;
1309
1310 use super::*;
1311
1312 #[test]
1313 fn test_schema_convert() {
1314 let column_schemas = vec![
1315 ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true),
1316 ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true),
1317 ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true),
1318 ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true),
1319 ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true),
1320 ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true),
1321 ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true),
1322 ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true),
1323 ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true),
1324 ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true),
1325 ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true),
1326 ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true),
1327 ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true),
1328 ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1329 ColumnSchema::new(
1330 "timestamps",
1331 ConcreteDataType::timestamp_millisecond_datatype(),
1332 true,
1333 ),
1334 ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true),
1335 ColumnSchema::new("times", ConcreteDataType::time_second_datatype(), true),
1336 ColumnSchema::new(
1337 "intervals",
1338 ConcreteDataType::interval_month_day_nano_datatype(),
1339 true,
1340 ),
1341 ];
1342 let pg_field_info = vec![
1343 FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
1344 FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
1345 FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
1346 FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
1347 FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
1348 FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
1349 FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
1350 FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
1351 FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
1352 FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
1353 FieldInfo::new(
1354 "float32s".into(),
1355 None,
1356 None,
1357 Type::FLOAT4,
1358 FieldFormat::Text,
1359 ),
1360 FieldInfo::new(
1361 "float64s".into(),
1362 None,
1363 None,
1364 Type::FLOAT8,
1365 FieldFormat::Text,
1366 ),
1367 FieldInfo::new(
1368 "binaries".into(),
1369 None,
1370 None,
1371 Type::BYTEA,
1372 FieldFormat::Text,
1373 ),
1374 FieldInfo::new(
1375 "strings".into(),
1376 None,
1377 None,
1378 Type::VARCHAR,
1379 FieldFormat::Text,
1380 ),
1381 FieldInfo::new(
1382 "timestamps".into(),
1383 None,
1384 None,
1385 Type::TIMESTAMP,
1386 FieldFormat::Text,
1387 ),
1388 FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
1389 FieldInfo::new("times".into(), None, None, Type::TIME, FieldFormat::Text),
1390 FieldInfo::new(
1391 "intervals".into(),
1392 None,
1393 None,
1394 Type::INTERVAL,
1395 FieldFormat::Text,
1396 ),
1397 ];
1398 let schema = Schema::new(column_schemas);
1399 let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
1400 assert_eq!(fs, pg_field_info);
1401 }
1402
1403 #[test]
1404 fn test_encode_text_format_data() {
1405 let schema = vec![
1406 FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
1407 FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
1408 FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
1409 FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
1410 FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
1411 FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
1412 FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
1413 FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
1414 FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
1415 FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
1416 FieldInfo::new(
1417 "float32s".into(),
1418 None,
1419 None,
1420 Type::FLOAT4,
1421 FieldFormat::Text,
1422 ),
1423 FieldInfo::new(
1424 "float64s".into(),
1425 None,
1426 None,
1427 Type::FLOAT8,
1428 FieldFormat::Text,
1429 ),
1430 FieldInfo::new(
1431 "strings".into(),
1432 None,
1433 None,
1434 Type::VARCHAR,
1435 FieldFormat::Text,
1436 ),
1437 FieldInfo::new(
1438 "binaries".into(),
1439 None,
1440 None,
1441 Type::BYTEA,
1442 FieldFormat::Text,
1443 ),
1444 FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
1445 FieldInfo::new("times".into(), None, None, Type::TIME, FieldFormat::Text),
1446 FieldInfo::new(
1447 "timestamps".into(),
1448 None,
1449 None,
1450 Type::TIMESTAMP,
1451 FieldFormat::Text,
1452 ),
1453 FieldInfo::new(
1454 "interval_year_month".into(),
1455 None,
1456 None,
1457 Type::INTERVAL,
1458 FieldFormat::Text,
1459 ),
1460 FieldInfo::new(
1461 "interval_day_time".into(),
1462 None,
1463 None,
1464 Type::INTERVAL,
1465 FieldFormat::Text,
1466 ),
1467 FieldInfo::new(
1468 "interval_month_day_nano".into(),
1469 None,
1470 None,
1471 Type::INTERVAL,
1472 FieldFormat::Text,
1473 ),
1474 FieldInfo::new(
1475 "int_list".into(),
1476 None,
1477 None,
1478 Type::INT8_ARRAY,
1479 FieldFormat::Text,
1480 ),
1481 FieldInfo::new(
1482 "float_list".into(),
1483 None,
1484 None,
1485 Type::FLOAT8_ARRAY,
1486 FieldFormat::Text,
1487 ),
1488 FieldInfo::new(
1489 "string_list".into(),
1490 None,
1491 None,
1492 Type::VARCHAR_ARRAY,
1493 FieldFormat::Text,
1494 ),
1495 FieldInfo::new(
1496 "timestamp_list".into(),
1497 None,
1498 None,
1499 Type::TIMESTAMP_ARRAY,
1500 FieldFormat::Text,
1501 ),
1502 ];
1503
1504 let arrow_schema = arrow_schema::Schema::new(vec![
1505 Field::new("x", DataType::Null, true),
1506 Field::new("x", DataType::Boolean, true),
1507 Field::new("x", DataType::UInt8, true),
1508 Field::new("x", DataType::UInt16, true),
1509 Field::new("x", DataType::UInt32, true),
1510 Field::new("x", DataType::UInt64, true),
1511 Field::new("x", DataType::Int8, true),
1512 Field::new("x", DataType::Int16, true),
1513 Field::new("x", DataType::Int32, true),
1514 Field::new("x", DataType::Int64, true),
1515 Field::new("x", DataType::Float32, true),
1516 Field::new("x", DataType::Float64, true),
1517 Field::new("x", DataType::Utf8, true),
1518 Field::new("x", DataType::Binary, true),
1519 Field::new("x", DataType::Date32, true),
1520 Field::new("x", DataType::Time32(TimeUnit::Second), true),
1521 Field::new("x", DataType::Timestamp(TimeUnit::Second, None), true),
1522 Field::new("x", DataType::Interval(IntervalUnit::YearMonth), true),
1523 Field::new("x", DataType::Interval(IntervalUnit::DayTime), true),
1524 Field::new("x", DataType::Interval(IntervalUnit::MonthDayNano), true),
1525 Field::new(
1526 "x",
1527 DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
1528 true,
1529 ),
1530 Field::new(
1531 "x",
1532 DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
1533 true,
1534 ),
1535 Field::new(
1536 "x",
1537 DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
1538 true,
1539 ),
1540 Field::new(
1541 "x",
1542 DataType::List(Arc::new(Field::new(
1543 "item",
1544 DataType::Timestamp(TimeUnit::Second, None),
1545 true,
1546 ))),
1547 true,
1548 ),
1549 ]);
1550
1551 let mut builder = ListBuilder::new(Int64Builder::new());
1552 builder.append_value([Some(1i64), None, Some(2)]);
1553 builder.append_null();
1554 builder.append_value([Some(-1i64), None, Some(-2)]);
1555 let i64_list_array = builder.finish();
1556
1557 let mut builder = ListBuilder::new(Float64Builder::new());
1558 builder.append_value([Some(1.0f64), None, Some(2.0)]);
1559 builder.append_null();
1560 builder.append_value([Some(-1.0f64), None, Some(-2.0)]);
1561 let f64_list_array = builder.finish();
1562
1563 let mut builder = ListBuilder::new(StringBuilder::new());
1564 builder.append_value([Some("a"), None, Some("b")]);
1565 builder.append_null();
1566 builder.append_value([Some("c"), None, Some("d")]);
1567 let string_list_array = builder.finish();
1568
1569 let mut builder = ListBuilder::new(TimestampSecondBuilder::new());
1570 builder.append_value([Some(1i64), None, Some(2)]);
1571 builder.append_null();
1572 builder.append_value([Some(3i64), None, Some(4)]);
1573 let timestamp_list_array = builder.finish();
1574
1575 let values = vec![
1576 Arc::new(NullVector::new(3)) as VectorRef,
1577 Arc::new(BooleanVector::from(vec![Some(true), Some(false), None])),
1578 Arc::new(UInt8Vector::from(vec![Some(u8::MAX), Some(u8::MIN), None])),
1579 Arc::new(UInt16Vector::from(vec![
1580 Some(u16::MAX),
1581 Some(u16::MIN),
1582 None,
1583 ])),
1584 Arc::new(UInt32Vector::from(vec![
1585 Some(u32::MAX),
1586 Some(u32::MIN),
1587 None,
1588 ])),
1589 Arc::new(UInt64Vector::from(vec![
1590 Some(u64::MAX),
1591 Some(u64::MIN),
1592 None,
1593 ])),
1594 Arc::new(Int8Vector::from(vec![Some(i8::MAX), Some(i8::MIN), None])),
1595 Arc::new(Int16Vector::from(vec![
1596 Some(i16::MAX),
1597 Some(i16::MIN),
1598 None,
1599 ])),
1600 Arc::new(Int32Vector::from(vec![
1601 Some(i32::MAX),
1602 Some(i32::MIN),
1603 None,
1604 ])),
1605 Arc::new(Int64Vector::from(vec![
1606 Some(i64::MAX),
1607 Some(i64::MIN),
1608 None,
1609 ])),
1610 Arc::new(Float32Vector::from(vec![
1611 None,
1612 Some(f32::MAX),
1613 Some(f32::MIN),
1614 ])),
1615 Arc::new(Float64Vector::from(vec![
1616 None,
1617 Some(f64::MAX),
1618 Some(f64::MIN),
1619 ])),
1620 Arc::new(StringVector::from(vec![
1621 None,
1622 Some("hello"),
1623 Some("greptime"),
1624 ])),
1625 Arc::new(BinaryVector::from(vec![
1626 None,
1627 Some("hello".as_bytes().to_vec()),
1628 Some("world".as_bytes().to_vec()),
1629 ])),
1630 Arc::new(DateVector::from(vec![Some(1001), None, Some(1)])),
1631 Arc::new(TimeSecondVector::from(vec![Some(1001), None, Some(1)])),
1632 Arc::new(TimestampSecondVector::from(vec![
1633 Some(1000001),
1634 None,
1635 Some(1),
1636 ])),
1637 Arc::new(IntervalYearMonthVector::from(vec![Some(1), None, Some(2)])),
1638 Arc::new(IntervalDayTimeVector::from(vec![
1639 Some(arrow::datatypes::IntervalDayTime::new(1, 1)),
1640 None,
1641 Some(arrow::datatypes::IntervalDayTime::new(2, 2)),
1642 ])),
1643 Arc::new(IntervalMonthDayNanoVector::from(vec![
1644 Some(arrow::datatypes::IntervalMonthDayNano::new(1, 1, 10)),
1645 None,
1646 Some(arrow::datatypes::IntervalMonthDayNano::new(2, 2, 20)),
1647 ])),
1648 Arc::new(ListVector::from(i64_list_array)),
1649 Arc::new(ListVector::from(f64_list_array)),
1650 Arc::new(ListVector::from(string_list_array)),
1651 Arc::new(ListVector::from(timestamp_list_array)),
1652 ];
1653 let record_batch =
1654 RecordBatch::new(Arc::new(arrow_schema.try_into().unwrap()), values).unwrap();
1655
1656 let query_context = QueryContextBuilder::default()
1657 .configuration_parameter(Default::default())
1658 .build()
1659 .into();
1660 let schema = Arc::new(schema);
1661
1662 let rows = RecordBatchRowIterator::new(query_context, schema.clone(), record_batch)
1663 .filter_map(|x| x.ok())
1664 .collect::<Vec<_>>();
1665 assert_eq!(rows.len(), 3);
1666 for row in rows {
1667 assert_eq!(row.field_count, schema.len() as i16);
1668 }
1669 }
1670
1671 #[test]
1672 fn test_invalid_parameter() {
1673 let msg = "invalid_parameter_count";
1675 let error = invalid_parameter_error(msg, None);
1676 if let PgWireError::UserError(value) = error {
1677 assert_eq!("ERROR", value.severity);
1678 assert_eq!("22023", value.code);
1679 assert_eq!(msg, value.message);
1680 } else {
1681 panic!("test_invalid_parameter failed");
1682 }
1683 }
1684}