1mod bytea;
16mod datetime;
17mod error;
18mod interval;
19
20use std::collections::HashMap;
21use std::sync::Arc;
22
23use arrow::array::{Array, ArrayRef, AsArray};
24use arrow::datatypes::{
25 Date32Type, Date64Type, Decimal128Type, Float32Type, Float64Type, Int8Type, Int16Type,
26 Int32Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
27 Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
28 TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
29 TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
30};
31use arrow_schema::{DataType, IntervalUnit, TimeUnit};
32use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
33use common_decimal::Decimal128;
34use common_recordbatch::RecordBatch;
35use common_time::time::Time;
36use common_time::{Date, IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth, Timestamp};
37use datafusion_common::ScalarValue;
38use datafusion_expr::LogicalPlan;
39use datatypes::arrow::datatypes::DataType as ArrowDataType;
40use datatypes::json::JsonStructureSettings;
41use datatypes::prelude::{ConcreteDataType, Value};
42use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
43use datatypes::types::{IntervalType, TimestampType, jsonb_to_string};
44use datatypes::value::StructValue;
45use pgwire::api::Type;
46use pgwire::api::portal::{Format, Portal};
47use pgwire::api::results::{DataRowEncoder, FieldInfo};
48use pgwire::error::{PgWireError, PgWireResult};
49use pgwire::messages::data::DataRow;
50use session::context::QueryContextRef;
51use session::session_config::PGByteaOutputValue;
52use snafu::ResultExt;
53
54use self::bytea::{EscapeOutputBytea, HexOutputBytea};
55use self::datetime::{StylingDate, StylingDateTime};
56pub use self::error::{PgErrorCode, PgErrorSeverity};
57use self::interval::PgInterval;
58use crate::SqlPlan;
59use crate::error::{self as server_error, DataFusionSnafu, Error, Result};
60use crate::postgres::utils::convert_err;
61
62pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
63 origin
64 .column_schemas()
65 .iter()
66 .enumerate()
67 .map(|(idx, col)| {
68 Ok(FieldInfo::new(
69 col.name.clone(),
70 None,
71 None,
72 type_gt_to_pg(&col.data_type)?,
73 field_formats.format_for(idx),
74 ))
75 })
76 .collect::<Result<Vec<FieldInfo>>>()
77}
78
79fn encode_struct(
89 _query_ctx: &QueryContextRef,
90 struct_value: StructValue,
91 builder: &mut DataRowEncoder,
92) -> PgWireResult<()> {
93 let encoding_setting = JsonStructureSettings::Structured(None);
94 let json_value = encoding_setting
95 .decode(Value::Struct(struct_value))
96 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
97
98 builder.encode_field(&json_value)
99}
100
101fn encode_array(
102 query_ctx: &QueryContextRef,
103 array: ArrayRef,
104 builder: &mut DataRowEncoder,
105) -> PgWireResult<()> {
106 macro_rules! encode_primitive_array {
107 ($array: ident, $data_type: ty, $lower_type: ty, $upper_type: ty) => {{
108 let array = $array.iter().collect::<Vec<Option<$data_type>>>();
109 if array
110 .iter()
111 .all(|x| x.is_none_or(|i| i <= <$lower_type>::MAX as $data_type))
112 {
113 builder.encode_field(
114 &array
115 .into_iter()
116 .map(|x| x.map(|i| i as $lower_type))
117 .collect::<Vec<Option<$lower_type>>>(),
118 )
119 } else {
120 builder.encode_field(
121 &array
122 .into_iter()
123 .map(|x| x.map(|i| i as $upper_type))
124 .collect::<Vec<Option<$upper_type>>>(),
125 )
126 }
127 }};
128 }
129
130 match array.data_type() {
131 DataType::Boolean => {
132 let array = array.as_boolean();
133 let array = array.iter().collect::<Vec<_>>();
134 builder.encode_field(&array)
135 }
136 DataType::Int8 => {
137 let array = array.as_primitive::<Int8Type>();
138 let array = array.iter().collect::<Vec<_>>();
139 builder.encode_field(&array)
140 }
141 DataType::Int16 => {
142 let array = array.as_primitive::<Int16Type>();
143 let array = array.iter().collect::<Vec<_>>();
144 builder.encode_field(&array)
145 }
146 DataType::Int32 => {
147 let array = array.as_primitive::<Int32Type>();
148 let array = array.iter().collect::<Vec<_>>();
149 builder.encode_field(&array)
150 }
151 DataType::Int64 => {
152 let array = array.as_primitive::<Int64Type>();
153 let array = array.iter().collect::<Vec<_>>();
154 builder.encode_field(&array)
155 }
156 DataType::UInt8 => {
157 let array = array.as_primitive::<UInt8Type>();
158 encode_primitive_array!(array, u8, i8, i16)
159 }
160 DataType::UInt16 => {
161 let array = array.as_primitive::<UInt16Type>();
162 encode_primitive_array!(array, u16, i16, i32)
163 }
164 DataType::UInt32 => {
165 let array = array.as_primitive::<UInt32Type>();
166 encode_primitive_array!(array, u32, i32, i64)
167 }
168 DataType::UInt64 => {
169 let array = array.as_primitive::<UInt64Type>();
170 let array = array.iter().collect::<Vec<_>>();
171 if array.iter().all(|x| x.is_none_or(|i| i <= i64::MAX as u64)) {
172 builder.encode_field(
173 &array
174 .into_iter()
175 .map(|x| x.map(|i| i as i64))
176 .collect::<Vec<Option<i64>>>(),
177 )
178 } else {
179 builder.encode_field(
180 &array
181 .into_iter()
182 .map(|x| x.map(|i| i.to_string()))
183 .collect::<Vec<_>>(),
184 )
185 }
186 }
187 DataType::Float32 => {
188 let array = array.as_primitive::<Float32Type>();
189 let array = array.iter().collect::<Vec<_>>();
190 builder.encode_field(&array)
191 }
192 DataType::Float64 => {
193 let array = array.as_primitive::<Float64Type>();
194 let array = array.iter().collect::<Vec<_>>();
195 builder.encode_field(&array)
196 }
197 DataType::Binary => {
198 let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
199
200 let array = array.as_binary::<i32>();
201 match *bytea_output {
202 PGByteaOutputValue::ESCAPE => {
203 let array = array
204 .iter()
205 .map(|v| v.map(EscapeOutputBytea))
206 .collect::<Vec<_>>();
207 builder.encode_field(&array)
208 }
209 PGByteaOutputValue::HEX => {
210 let array = array
211 .iter()
212 .map(|v| v.map(HexOutputBytea))
213 .collect::<Vec<_>>();
214 builder.encode_field(&array)
215 }
216 }
217 }
218 DataType::Utf8 => {
219 let array = array.as_string::<i32>();
220 let array = array.into_iter().collect::<Vec<_>>();
221 builder.encode_field(&array)
222 }
223 DataType::LargeUtf8 => {
224 let array = array.as_string::<i64>();
225 let array = array.into_iter().collect::<Vec<_>>();
226 builder.encode_field(&array)
227 }
228 DataType::Utf8View => {
229 let array = array.as_string_view();
230 let array = array.into_iter().collect::<Vec<_>>();
231 builder.encode_field(&array)
232 }
233 DataType::Date32 | DataType::Date64 => {
234 let iter: Box<dyn Iterator<Item = Option<i32>>> =
235 if matches!(array.data_type(), DataType::Date32) {
236 let array = array.as_primitive::<Date32Type>();
237 Box::new(array.into_iter())
238 } else {
239 let array = array.as_primitive::<Date64Type>();
240 Box::new(
244 array
245 .into_iter()
246 .map(|x| x.map(|i| (i / 86_400_000) as i32)),
247 )
248 };
249 let array = iter
250 .into_iter()
251 .map(|v| match v {
252 None => Ok(None),
253 Some(v) => {
254 if let Some(date) = Date::new(v).to_chrono_date() {
255 let (style, order) =
256 *query_ctx.configuration_parameter().pg_datetime_style();
257 Ok(Some(StylingDate(date, style, order)))
258 } else {
259 Err(convert_err(Error::Internal {
260 err_msg: format!("Failed to convert date to postgres type {v:?}",),
261 }))
262 }
263 }
264 })
265 .collect::<PgWireResult<Vec<Option<StylingDate>>>>()?;
266 builder.encode_field(&array)
267 }
268 DataType::Timestamp(time_unit, _) => {
269 let array = match time_unit {
270 TimeUnit::Second => {
271 let array = array.as_primitive::<TimestampSecondType>();
272 array.into_iter().collect::<Vec<_>>()
273 }
274 TimeUnit::Millisecond => {
275 let array = array.as_primitive::<TimestampMillisecondType>();
276 array.into_iter().collect::<Vec<_>>()
277 }
278 TimeUnit::Microsecond => {
279 let array = array.as_primitive::<TimestampMicrosecondType>();
280 array.into_iter().collect::<Vec<_>>()
281 }
282 TimeUnit::Nanosecond => {
283 let array = array.as_primitive::<TimestampNanosecondType>();
284 array.into_iter().collect::<Vec<_>>()
285 }
286 };
287 let time_unit = time_unit.into();
288 let array = array
289 .into_iter()
290 .map(|v| match v {
291 None => Ok(None),
292 Some(v) => {
293 let v = Timestamp::new(v, time_unit);
294 if let Some(datetime) =
295 v.to_chrono_datetime_with_timezone(Some(&query_ctx.timezone()))
296 {
297 let (style, order) =
298 *query_ctx.configuration_parameter().pg_datetime_style();
299 Ok(Some(StylingDateTime(datetime, style, order)))
300 } else {
301 Err(convert_err(Error::Internal {
302 err_msg: format!("Failed to convert date to postgres type {v:?}",),
303 }))
304 }
305 }
306 })
307 .collect::<PgWireResult<Vec<Option<StylingDateTime>>>>()?;
308 builder.encode_field(&array)
309 }
310 DataType::Time32(time_unit) | DataType::Time64(time_unit) => {
311 let iter: Box<dyn Iterator<Item = Option<Time>>> = match time_unit {
312 TimeUnit::Second => {
313 let array = array.as_primitive::<Time32SecondType>();
314 Box::new(
315 array
316 .into_iter()
317 .map(|v| v.map(|i| Time::new_second(i as i64))),
318 )
319 }
320 TimeUnit::Millisecond => {
321 let array = array.as_primitive::<Time32MillisecondType>();
322 Box::new(
323 array
324 .into_iter()
325 .map(|v| v.map(|i| Time::new_millisecond(i as i64))),
326 )
327 }
328 TimeUnit::Microsecond => {
329 let array = array.as_primitive::<Time64MicrosecondType>();
330 Box::new(array.into_iter().map(|v| v.map(Time::new_microsecond)))
331 }
332 TimeUnit::Nanosecond => {
333 let array = array.as_primitive::<Time64NanosecondType>();
334 Box::new(array.into_iter().map(|v| v.map(Time::new_nanosecond)))
335 }
336 };
337 let array = iter
338 .into_iter()
339 .map(|v| v.and_then(|v| v.to_chrono_time()))
340 .collect::<Vec<Option<NaiveTime>>>();
341 builder.encode_field(&array)
342 }
343 DataType::Interval(interval_unit) => {
344 let array = match interval_unit {
345 IntervalUnit::YearMonth => {
346 let array = array.as_primitive::<IntervalYearMonthType>();
347 array
348 .into_iter()
349 .map(|v| v.map(|i| PgInterval::from(IntervalYearMonth::from(i))))
350 .collect::<Vec<_>>()
351 }
352 IntervalUnit::DayTime => {
353 let array = array.as_primitive::<IntervalDayTimeType>();
354 array
355 .into_iter()
356 .map(|v| v.map(|i| PgInterval::from(IntervalDayTime::from(i))))
357 .collect::<Vec<_>>()
358 }
359 IntervalUnit::MonthDayNano => {
360 let array = array.as_primitive::<IntervalMonthDayNanoType>();
361 array
362 .into_iter()
363 .map(|v| v.map(|i| PgInterval::from(IntervalMonthDayNano::from(i))))
364 .collect::<Vec<_>>()
365 }
366 };
367 builder.encode_field(&array)
368 }
369 DataType::Decimal128(precision, scale) => {
370 let array = array.as_primitive::<Decimal128Type>();
371 let array = array
372 .into_iter()
373 .map(|v| v.map(|i| Decimal128::new(i, *precision, *scale).to_string()))
374 .collect::<Vec<_>>();
375 builder.encode_field(&array)
376 }
377 _ => Err(convert_err(Error::Internal {
378 err_msg: format!(
379 "cannot write array type {:?} in postgres protocol: unimplemented",
380 array.data_type()
381 ),
382 })),
383 }
384}
385
/// Iterator that turns each row of a record batch into a pgwire `DataRow`.
pub(crate) struct RecordBatchRowIterator {
    /// Session context; its time zone, datetime style and `bytea` output
    /// settings are read while encoding rows.
    query_ctx: QueryContextRef,
    /// Field descriptions shared with every `DataRowEncoder` created for a row.
    pg_schema: Arc<Vec<FieldInfo>>,
    /// GreptimeDB schema of the batch. It is used to look up a column's
    /// logical type when the arrow type alone is not enough, e.g. JSON stored
    /// as binary.
    schema: SchemaRef,
    /// The arrow record batch whose rows are being encoded.
    record_batch: arrow::record_batch::RecordBatch,
    /// Index of the next row to encode.
    i: usize,
}
393
394impl Iterator for RecordBatchRowIterator {
395 type Item = PgWireResult<DataRow>;
396
397 fn next(&mut self) -> Option<Self::Item> {
398 let mut encoder = DataRowEncoder::new(self.pg_schema.clone());
399 if self.i < self.record_batch.num_rows() {
400 if let Err(e) = self.encode_row(self.i, &mut encoder) {
401 return Some(Err(e));
402 }
403 self.i += 1;
404 Some(Ok(encoder.take_row()))
405 } else {
406 None
407 }
408 }
409}
410
411impl RecordBatchRowIterator {
412 pub(crate) fn new(
413 query_ctx: QueryContextRef,
414 pg_schema: Arc<Vec<FieldInfo>>,
415 record_batch: RecordBatch,
416 ) -> Self {
417 let schema = record_batch.schema.clone();
418 let record_batch = record_batch.into_df_record_batch();
419 Self {
420 query_ctx,
421 pg_schema,
422 schema,
423 record_batch,
424 i: 0,
425 }
426 }
427
428 fn encode_row(&mut self, i: usize, encoder: &mut DataRowEncoder) -> PgWireResult<()> {
429 for (j, column) in self.record_batch.columns().iter().enumerate() {
430 if column.is_null(i) {
431 encoder.encode_field(&None::<&i8>)?;
432 continue;
433 }
434
435 match column.data_type() {
436 DataType::Null => {
437 encoder.encode_field(&None::<&i8>)?;
438 }
439 DataType::Boolean => {
440 let array = column.as_boolean();
441 encoder.encode_field(&array.value(i))?;
442 }
443 DataType::UInt8 => {
444 let array = column.as_primitive::<UInt8Type>();
445 let value = array.value(i);
446 if value <= i8::MAX as u8 {
447 encoder.encode_field(&(value as i8))?;
448 } else {
449 encoder.encode_field(&(value as i16))?;
450 }
451 }
452 DataType::UInt16 => {
453 let array = column.as_primitive::<UInt16Type>();
454 let value = array.value(i);
455 if value <= i16::MAX as u16 {
456 encoder.encode_field(&(value as i16))?;
457 } else {
458 encoder.encode_field(&(value as i32))?;
459 }
460 }
461 DataType::UInt32 => {
462 let array = column.as_primitive::<UInt32Type>();
463 let value = array.value(i);
464 if value <= i32::MAX as u32 {
465 encoder.encode_field(&(value as i32))?;
466 } else {
467 encoder.encode_field(&(value as i64))?;
468 }
469 }
470 DataType::UInt64 => {
471 let array = column.as_primitive::<UInt64Type>();
472 let value = array.value(i);
473 if value <= i64::MAX as u64 {
474 encoder.encode_field(&(value as i64))?;
475 } else {
476 encoder.encode_field(&value.to_string())?;
477 }
478 }
479 DataType::Int8 => {
480 let array = column.as_primitive::<Int8Type>();
481 encoder.encode_field(&array.value(i))?;
482 }
483 DataType::Int16 => {
484 let array = column.as_primitive::<Int16Type>();
485 encoder.encode_field(&array.value(i))?;
486 }
487 DataType::Int32 => {
488 let array = column.as_primitive::<Int32Type>();
489 encoder.encode_field(&array.value(i))?;
490 }
491 DataType::Int64 => {
492 let array = column.as_primitive::<Int64Type>();
493 encoder.encode_field(&array.value(i))?;
494 }
495 DataType::Float32 => {
496 let array = column.as_primitive::<Float32Type>();
497 encoder.encode_field(&array.value(i))?;
498 }
499 DataType::Float64 => {
500 let array = column.as_primitive::<Float64Type>();
501 encoder.encode_field(&array.value(i))?;
502 }
503 DataType::Utf8 => {
504 let array = column.as_string::<i32>();
505 let value = array.value(i);
506 encoder.encode_field(&value)?;
507 }
508 DataType::Utf8View => {
509 let array = column.as_string_view();
510 let value = array.value(i);
511 encoder.encode_field(&value)?;
512 }
513 DataType::LargeUtf8 => {
514 let array = column.as_string::<i64>();
515 let value = array.value(i);
516 encoder.encode_field(&value)?;
517 }
518 DataType::Binary => {
519 let array = column.as_binary::<i32>();
520 let v = array.value(i);
521 encode_bytes(
522 &self.schema.column_schemas()[j],
523 v,
524 encoder,
525 &self.query_ctx,
526 )?;
527 }
528 DataType::BinaryView => {
529 let array = column.as_binary_view();
530 let v = array.value(i);
531 encode_bytes(
532 &self.schema.column_schemas()[j],
533 v,
534 encoder,
535 &self.query_ctx,
536 )?;
537 }
538 DataType::LargeBinary => {
539 let array = column.as_binary::<i64>();
540 let v = array.value(i);
541 encode_bytes(
542 &self.schema.column_schemas()[j],
543 v,
544 encoder,
545 &self.query_ctx,
546 )?;
547 }
548 DataType::Date32 | DataType::Date64 => {
549 let v = if matches!(column.data_type(), DataType::Date32) {
550 let array = column.as_primitive::<Date32Type>();
551 array.value(i)
552 } else {
553 let array = column.as_primitive::<Date64Type>();
554 (array.value(i) / 86_400_000) as i32
558 };
559 let v = Date::new(v);
560 let date = v.to_chrono_date().map(|v| {
561 let (style, order) =
562 *self.query_ctx.configuration_parameter().pg_datetime_style();
563 StylingDate(v, style, order)
564 });
565 encoder.encode_field(&date)?;
566 }
567 DataType::Timestamp(_, _) => {
568 let v = datatypes::arrow_array::timestamp_array_value(column, i);
569 let datetime = v
570 .to_chrono_datetime_with_timezone(Some(&self.query_ctx.timezone()))
571 .map(|v| {
572 let (style, order) =
573 *self.query_ctx.configuration_parameter().pg_datetime_style();
574 StylingDateTime(v, style, order)
575 });
576 encoder.encode_field(&datetime)?;
577 }
578 DataType::Interval(interval_unit) => match interval_unit {
579 IntervalUnit::YearMonth => {
580 let array = column.as_primitive::<IntervalYearMonthType>();
581 let v: IntervalYearMonth = array.value(i).into();
582 encoder.encode_field(&PgInterval::from(v))?;
583 }
584 IntervalUnit::DayTime => {
585 let array = column.as_primitive::<IntervalDayTimeType>();
586 let v: IntervalDayTime = array.value(i).into();
587 encoder.encode_field(&PgInterval::from(v))?;
588 }
589 IntervalUnit::MonthDayNano => {
590 let array = column.as_primitive::<IntervalMonthDayNanoType>();
591 let v: IntervalMonthDayNano = array.value(i).into();
592 encoder.encode_field(&PgInterval::from(v))?;
593 }
594 },
595 DataType::Duration(_) => {
596 let d = datatypes::arrow_array::duration_array_value(column, i);
597 match PgInterval::try_from(d) {
598 Ok(i) => encoder.encode_field(&i)?,
599 Err(e) => {
600 return Err(convert_err(Error::Internal {
601 err_msg: e.to_string(),
602 }));
603 }
604 }
605 }
606 DataType::List(_) => {
607 let array = column.as_list::<i32>();
608 let items = array.value(i);
609 encode_array(&self.query_ctx, items, encoder)?;
610 }
611 DataType::Struct(_) => {
612 encode_struct(&self.query_ctx, Default::default(), encoder)?;
613 }
614 DataType::Time32(_) | DataType::Time64(_) => {
615 let v = datatypes::arrow_array::time_array_value(column, i);
616 encoder.encode_field(&v.to_chrono_time())?;
617 }
618 DataType::Decimal128(precision, scale) => {
619 let array = column.as_primitive::<Decimal128Type>();
620 let v = Decimal128::new(array.value(i), *precision, *scale);
621 encoder.encode_field(&v.to_string())?;
622 }
623 _ => {
624 return Err(convert_err(Error::Internal {
625 err_msg: format!(
626 "cannot convert datatype {} to postgres",
627 column.data_type()
628 ),
629 }));
630 }
631 }
632 }
633 Ok(())
634 }
635}
636
637fn encode_bytes(
638 schema: &ColumnSchema,
639 v: &[u8],
640 encoder: &mut DataRowEncoder,
641 query_ctx: &QueryContextRef,
642) -> PgWireResult<()> {
643 if let ConcreteDataType::Json(_) = &schema.data_type {
644 let s = jsonb_to_string(v).map_err(convert_err)?;
645 encoder.encode_field(&s)
646 } else {
647 let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
648 match *bytea_output {
649 PGByteaOutputValue::ESCAPE => encoder.encode_field(&EscapeOutputBytea(v)),
650 PGByteaOutputValue::HEX => encoder.encode_field(&HexOutputBytea(v)),
651 }
652 }
653}
654
655pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
656 match origin {
657 &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
658 &ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
659 &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR),
660 &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2),
661 &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4),
662 &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
663 &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
664 &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
665 &ConcreteDataType::Binary(_) | &ConcreteDataType::Vector(_) => Ok(Type::BYTEA),
666 &ConcreteDataType::String(_) => Ok(Type::VARCHAR),
667 &ConcreteDataType::Date(_) => Ok(Type::DATE),
668 &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP),
669 &ConcreteDataType::Time(_) => Ok(Type::TIME),
670 &ConcreteDataType::Interval(_) => Ok(Type::INTERVAL),
671 &ConcreteDataType::Decimal128(_) => Ok(Type::NUMERIC),
672 &ConcreteDataType::Json(_) => Ok(Type::JSON),
673 ConcreteDataType::List(list) => match list.item_type() {
674 &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
675 &ConcreteDataType::Boolean(_) => Ok(Type::BOOL_ARRAY),
676 &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR_ARRAY),
677 &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2_ARRAY),
678 &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4_ARRAY),
679 &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8_ARRAY),
680 &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4_ARRAY),
681 &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8_ARRAY),
682 &ConcreteDataType::Binary(_) => Ok(Type::BYTEA_ARRAY),
683 &ConcreteDataType::String(_) => Ok(Type::VARCHAR_ARRAY),
684 &ConcreteDataType::Date(_) => Ok(Type::DATE_ARRAY),
685 &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP_ARRAY),
686 &ConcreteDataType::Time(_) => Ok(Type::TIME_ARRAY),
687 &ConcreteDataType::Interval(_) => Ok(Type::INTERVAL_ARRAY),
688 &ConcreteDataType::Decimal128(_) => Ok(Type::NUMERIC_ARRAY),
689 &ConcreteDataType::Json(_) => Ok(Type::JSON_ARRAY),
690 &ConcreteDataType::Duration(_) => Ok(Type::INTERVAL_ARRAY),
691 &ConcreteDataType::Struct(_) => Ok(Type::JSON_ARRAY),
692 &ConcreteDataType::Dictionary(_)
693 | &ConcreteDataType::Vector(_)
694 | &ConcreteDataType::List(_) => server_error::UnsupportedDataTypeSnafu {
695 data_type: origin,
696 reason: "not implemented",
697 }
698 .fail(),
699 },
700 &ConcreteDataType::Dictionary(_) => server_error::UnsupportedDataTypeSnafu {
701 data_type: origin,
702 reason: "not implemented",
703 }
704 .fail(),
705 &ConcreteDataType::Duration(_) => Ok(Type::INTERVAL),
706 &ConcreteDataType::Struct(_) => Ok(Type::JSON),
707 }
708}
709
710#[allow(dead_code)]
711pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
712 match origin {
714 &Type::BOOL => Ok(ConcreteDataType::boolean_datatype()),
715 &Type::CHAR => Ok(ConcreteDataType::int8_datatype()),
716 &Type::INT2 => Ok(ConcreteDataType::int16_datatype()),
717 &Type::INT4 => Ok(ConcreteDataType::int32_datatype()),
718 &Type::INT8 => Ok(ConcreteDataType::int64_datatype()),
719 &Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
720 &Type::TIMESTAMP | &Type::TIMESTAMPTZ => Ok(ConcreteDataType::timestamp_datatype(
721 common_time::timestamp::TimeUnit::Millisecond,
722 )),
723 &Type::DATE => Ok(ConcreteDataType::date_datatype()),
724 &Type::TIME => Ok(ConcreteDataType::timestamp_datatype(
725 common_time::timestamp::TimeUnit::Microsecond,
726 )),
727 &Type::CHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
728 ConcreteDataType::int8_datatype(),
729 ))),
730 &Type::INT2_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
731 ConcreteDataType::int16_datatype(),
732 ))),
733 &Type::INT4_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
734 ConcreteDataType::int32_datatype(),
735 ))),
736 &Type::INT8_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
737 ConcreteDataType::int64_datatype(),
738 ))),
739 &Type::VARCHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
740 ConcreteDataType::string_datatype(),
741 ))),
742 _ => server_error::InternalSnafu {
743 err_msg: format!("unimplemented datatype {origin:?}"),
744 }
745 .fail(),
746 }
747}
748
749pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWireResult<String> {
750 let param_type = portal
753 .statement
754 .parameter_types
755 .get(idx)
756 .unwrap()
757 .as_ref()
758 .unwrap_or(&Type::UNKNOWN);
759 match param_type {
760 &Type::VARCHAR | &Type::TEXT => Ok(format!(
761 "'{}'",
762 portal
763 .parameter::<String>(idx, param_type)?
764 .as_deref()
765 .unwrap_or("")
766 )),
767 &Type::BOOL => Ok(portal
768 .parameter::<bool>(idx, param_type)?
769 .map(|v| v.to_string())
770 .unwrap_or_else(|| "".to_owned())),
771 &Type::INT4 => Ok(portal
772 .parameter::<i32>(idx, param_type)?
773 .map(|v| v.to_string())
774 .unwrap_or_else(|| "".to_owned())),
775 &Type::INT8 => Ok(portal
776 .parameter::<i64>(idx, param_type)?
777 .map(|v| v.to_string())
778 .unwrap_or_else(|| "".to_owned())),
779 &Type::FLOAT4 => Ok(portal
780 .parameter::<f32>(idx, param_type)?
781 .map(|v| v.to_string())
782 .unwrap_or_else(|| "".to_owned())),
783 &Type::FLOAT8 => Ok(portal
784 .parameter::<f64>(idx, param_type)?
785 .map(|v| v.to_string())
786 .unwrap_or_else(|| "".to_owned())),
787 &Type::DATE => Ok(portal
788 .parameter::<NaiveDate>(idx, param_type)?
789 .map(|v| v.format("%Y-%m-%d").to_string())
790 .unwrap_or_else(|| "".to_owned())),
791 &Type::TIMESTAMP => Ok(portal
792 .parameter::<NaiveDateTime>(idx, param_type)?
793 .map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
794 .unwrap_or_else(|| "".to_owned())),
795 &Type::INTERVAL => Ok(portal
796 .parameter::<PgInterval>(idx, param_type)?
797 .map(|v| v.to_string())
798 .unwrap_or_else(|| "".to_owned())),
799 _ => Err(invalid_parameter_error(
800 "unsupported_parameter_type",
801 Some(param_type.to_string()),
802 )),
803 }
804}
805
806pub(super) fn invalid_parameter_error(msg: &str, detail: Option<String>) -> PgWireError {
807 let mut error_info = PgErrorCode::Ec22023.to_err_info(msg.to_string());
808 error_info.detail = detail;
809 PgWireError::UserError(Box::new(error_info))
810}
811
812fn to_timestamp_scalar_value<T>(
813 data: Option<T>,
814 unit: &TimestampType,
815 ctype: &ConcreteDataType,
816) -> PgWireResult<ScalarValue>
817where
818 T: Into<i64>,
819{
820 if let Some(n) = data {
821 Value::Timestamp(unit.create_timestamp(n.into()))
822 .try_to_scalar_value(ctype)
823 .map_err(convert_err)
824 } else {
825 Ok(ScalarValue::Null)
826 }
827}
828
829pub(super) fn parameters_to_scalar_values(
830 plan: &LogicalPlan,
831 portal: &Portal<SqlPlan>,
832) -> PgWireResult<Vec<ScalarValue>> {
833 let param_count = portal.parameter_len();
834 let mut results = Vec::with_capacity(param_count);
835
836 let client_param_types = &portal.statement.parameter_types;
837 let server_param_types = plan
838 .get_parameter_types()
839 .context(DataFusionSnafu)
840 .map_err(convert_err)?
841 .into_iter()
842 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
843 .collect::<HashMap<_, _>>();
844
845 for idx in 0..param_count {
846 let server_type = server_param_types
847 .get(&format!("${}", idx + 1))
848 .and_then(|t| t.as_ref());
849
850 let client_type = if let Some(Some(client_given_type)) = client_param_types.get(idx) {
851 client_given_type.clone()
852 } else if let Some(server_provided_type) = &server_type {
853 type_gt_to_pg(server_provided_type).map_err(convert_err)?
854 } else {
855 return Err(invalid_parameter_error(
856 "unknown_parameter_type",
857 Some(format!(
858 "Cannot get parameter type information for parameter {}",
859 idx
860 )),
861 ));
862 };
863
864 let value = match &client_type {
865 &Type::VARCHAR | &Type::TEXT => {
866 let data = portal.parameter::<String>(idx, &client_type)?;
867 if let Some(server_type) = &server_type {
868 match server_type {
869 ConcreteDataType::String(t) => {
870 if t.is_large() {
871 ScalarValue::LargeUtf8(data)
872 } else {
873 ScalarValue::Utf8(data)
874 }
875 }
876 _ => {
877 return Err(invalid_parameter_error(
878 "invalid_parameter_type",
879 Some(format!("Expected: {}, found: {}", server_type, client_type)),
880 ));
881 }
882 }
883 } else {
884 ScalarValue::Utf8(data)
885 }
886 }
887 &Type::BOOL => {
888 let data = portal.parameter::<bool>(idx, &client_type)?;
889 if let Some(server_type) = &server_type {
890 match server_type {
891 ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data),
892 _ => {
893 return Err(invalid_parameter_error(
894 "invalid_parameter_type",
895 Some(format!("Expected: {}, found: {}", server_type, client_type)),
896 ));
897 }
898 }
899 } else {
900 ScalarValue::Boolean(data)
901 }
902 }
903 &Type::INT2 => {
904 let data = portal.parameter::<i16>(idx, &client_type)?;
905 if let Some(server_type) = &server_type {
906 match server_type {
907 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
908 ConcreteDataType::Int16(_) => ScalarValue::Int16(data),
909 ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
910 ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
911 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
912 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
913 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
914 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
915 ConcreteDataType::Timestamp(unit) => {
916 to_timestamp_scalar_value(data, unit, server_type)?
917 }
918 _ => {
919 return Err(invalid_parameter_error(
920 "invalid_parameter_type",
921 Some(format!("Expected: {}, found: {}", server_type, client_type)),
922 ));
923 }
924 }
925 } else {
926 ScalarValue::Int16(data)
927 }
928 }
929 &Type::INT4 => {
930 let data = portal.parameter::<i32>(idx, &client_type)?;
931 if let Some(server_type) = &server_type {
932 match server_type {
933 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
934 ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
935 ConcreteDataType::Int32(_) => ScalarValue::Int32(data),
936 ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
937 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
938 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
939 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
940 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
941 ConcreteDataType::Timestamp(unit) => {
942 to_timestamp_scalar_value(data, unit, server_type)?
943 }
944 _ => {
945 return Err(invalid_parameter_error(
946 "invalid_parameter_type",
947 Some(format!("Expected: {}, found: {}", server_type, client_type)),
948 ));
949 }
950 }
951 } else {
952 ScalarValue::Int32(data)
953 }
954 }
955 &Type::INT8 => {
956 let data = portal.parameter::<i64>(idx, &client_type)?;
957 if let Some(server_type) = &server_type {
958 match server_type {
959 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
960 ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
961 ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
962 ConcreteDataType::Int64(_) => ScalarValue::Int64(data),
963 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
964 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
965 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
966 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
967 ConcreteDataType::Timestamp(unit) => {
968 to_timestamp_scalar_value(data, unit, server_type)?
969 }
970 _ => {
971 return Err(invalid_parameter_error(
972 "invalid_parameter_type",
973 Some(format!("Expected: {}, found: {}", server_type, client_type)),
974 ));
975 }
976 }
977 } else {
978 ScalarValue::Int64(data)
979 }
980 }
981 &Type::FLOAT4 => {
982 let data = portal.parameter::<f32>(idx, &client_type)?;
983 if let Some(server_type) = &server_type {
984 match server_type {
985 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
986 ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
987 ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
988 ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
989 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
990 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
991 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
992 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
993 ConcreteDataType::Float32(_) => ScalarValue::Float32(data),
994 ConcreteDataType::Float64(_) => {
995 ScalarValue::Float64(data.map(|n| n as f64))
996 }
997 _ => {
998 return Err(invalid_parameter_error(
999 "invalid_parameter_type",
1000 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1001 ));
1002 }
1003 }
1004 } else {
1005 ScalarValue::Float32(data)
1006 }
1007 }
1008 &Type::FLOAT8 => {
1009 let data = portal.parameter::<f64>(idx, &client_type)?;
1010 if let Some(server_type) = &server_type {
1011 match server_type {
1012 ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
1013 ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
1014 ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
1015 ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
1016 ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
1017 ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
1018 ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
1019 ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
1020 ConcreteDataType::Float32(_) => {
1021 ScalarValue::Float32(data.map(|n| n as f32))
1022 }
1023 ConcreteDataType::Float64(_) => ScalarValue::Float64(data),
1024 _ => {
1025 return Err(invalid_parameter_error(
1026 "invalid_parameter_type",
1027 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1028 ));
1029 }
1030 }
1031 } else {
1032 ScalarValue::Float64(data)
1033 }
1034 }
1035 &Type::TIMESTAMP => {
1036 let data = portal.parameter::<NaiveDateTime>(idx, &client_type)?;
1037 if let Some(server_type) = &server_type {
1038 match server_type {
1039 ConcreteDataType::Timestamp(unit) => match *unit {
1040 TimestampType::Second(_) => ScalarValue::TimestampSecond(
1041 data.map(|ts| ts.and_utc().timestamp()),
1042 None,
1043 ),
1044 TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond(
1045 data.map(|ts| ts.and_utc().timestamp_millis()),
1046 None,
1047 ),
1048 TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond(
1049 data.map(|ts| ts.and_utc().timestamp_micros()),
1050 None,
1051 ),
1052 TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond(
1053 data.and_then(|ts| ts.and_utc().timestamp_nanos_opt()),
1054 None,
1055 ),
1056 },
1057 _ => {
1058 return Err(invalid_parameter_error(
1059 "invalid_parameter_type",
1060 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1061 ));
1062 }
1063 }
1064 } else {
1065 ScalarValue::TimestampMillisecond(
1066 data.map(|ts| ts.and_utc().timestamp_millis()),
1067 None,
1068 )
1069 }
1070 }
1071 &Type::TIMESTAMPTZ => {
1072 let data = portal.parameter::<DateTime<FixedOffset>>(idx, &client_type)?;
1073 if let Some(server_type) = &server_type {
1074 match server_type {
1075 ConcreteDataType::Timestamp(unit) => match *unit {
1076 TimestampType::Second(_) => {
1077 ScalarValue::TimestampSecond(data.map(|ts| ts.timestamp()), None)
1078 }
1079 TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond(
1080 data.map(|ts| ts.timestamp_millis()),
1081 None,
1082 ),
1083 TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond(
1084 data.map(|ts| ts.timestamp_micros()),
1085 None,
1086 ),
1087 TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond(
1088 data.and_then(|ts| ts.timestamp_nanos_opt()),
1089 None,
1090 ),
1091 },
1092 _ => {
1093 return Err(invalid_parameter_error(
1094 "invalid_parameter_type",
1095 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1096 ));
1097 }
1098 }
1099 } else {
1100 ScalarValue::TimestampMillisecond(data.map(|ts| ts.timestamp_millis()), None)
1101 }
1102 }
1103 &Type::DATE => {
1104 let data = portal.parameter::<NaiveDate>(idx, &client_type)?;
1105 if let Some(server_type) = &server_type {
1106 match server_type {
1107 ConcreteDataType::Date(_) => ScalarValue::Date32(
1108 data.map(|d| (d - DateTime::UNIX_EPOCH.date_naive()).num_days() as i32),
1109 ),
1110 _ => {
1111 return Err(invalid_parameter_error(
1112 "invalid_parameter_type",
1113 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1114 ));
1115 }
1116 }
1117 } else {
1118 ScalarValue::Date32(
1119 data.map(|d| (d - DateTime::UNIX_EPOCH.date_naive()).num_days() as i32),
1120 )
1121 }
1122 }
1123 &Type::INTERVAL => {
1124 let data = portal.parameter::<PgInterval>(idx, &client_type)?;
1125 if let Some(server_type) = &server_type {
1126 match server_type {
1127 ConcreteDataType::Interval(IntervalType::YearMonth(_)) => {
1128 ScalarValue::IntervalYearMonth(
1129 data.map(|i| {
1130 if i.days != 0 || i.microseconds != 0 {
1131 Err(invalid_parameter_error(
1132 "invalid_parameter_type",
1133 Some(format!(
1134 "Expected: {}, found: {}",
1135 server_type, client_type
1136 )),
1137 ))
1138 } else {
1139 Ok(IntervalYearMonth::new(i.months).to_i32())
1140 }
1141 })
1142 .transpose()?,
1143 )
1144 }
1145 ConcreteDataType::Interval(IntervalType::DayTime(_)) => {
1146 ScalarValue::IntervalDayTime(
1147 data.map(|i| {
1148 if i.months != 0 || i.microseconds % 1000 != 0 {
1149 Err(invalid_parameter_error(
1150 "invalid_parameter_type",
1151 Some(format!(
1152 "Expected: {}, found: {}",
1153 server_type, client_type
1154 )),
1155 ))
1156 } else {
1157 Ok(IntervalDayTime::new(
1158 i.days,
1159 (i.microseconds / 1000) as i32,
1160 )
1161 .into())
1162 }
1163 })
1164 .transpose()?,
1165 )
1166 }
1167 ConcreteDataType::Interval(IntervalType::MonthDayNano(_)) => {
1168 ScalarValue::IntervalMonthDayNano(
1169 data.map(|i| IntervalMonthDayNano::from(i).into()),
1170 )
1171 }
1172 _ => {
1173 return Err(invalid_parameter_error(
1174 "invalid_parameter_type",
1175 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1176 ));
1177 }
1178 }
1179 } else {
1180 ScalarValue::IntervalMonthDayNano(
1181 data.map(|i| IntervalMonthDayNano::from(i).into()),
1182 )
1183 }
1184 }
1185 &Type::BYTEA => {
1186 let data = portal.parameter::<Vec<u8>>(idx, &client_type)?;
1187 if let Some(server_type) = &server_type {
1188 match server_type {
1189 ConcreteDataType::String(t) => {
1190 let s = data.map(|d| String::from_utf8_lossy(&d).to_string());
1191 if t.is_large() {
1192 ScalarValue::LargeUtf8(s)
1193 } else {
1194 ScalarValue::Utf8(s)
1195 }
1196 }
1197 ConcreteDataType::Binary(_) => ScalarValue::Binary(data),
1198 _ => {
1199 return Err(invalid_parameter_error(
1200 "invalid_parameter_type",
1201 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1202 ));
1203 }
1204 }
1205 } else {
1206 ScalarValue::Binary(data)
1207 }
1208 }
1209 &Type::JSONB => {
1210 let data = portal.parameter::<serde_json::Value>(idx, &client_type)?;
1211 if let Some(server_type) = &server_type {
1212 match server_type {
1213 ConcreteDataType::Binary(_) => {
1214 ScalarValue::Binary(data.map(|d| d.to_string().into_bytes()))
1215 }
1216 _ => {
1217 return Err(invalid_parameter_error(
1218 "invalid_parameter_type",
1219 Some(format!("Expected: {}, found: {}", server_type, client_type)),
1220 ));
1221 }
1222 }
1223 } else {
1224 ScalarValue::Binary(data.map(|d| d.to_string().into_bytes()))
1225 }
1226 }
1227 &Type::INT2_ARRAY => {
1228 let data = portal.parameter::<Vec<i16>>(idx, &client_type)?;
1229 if let Some(data) = data {
1230 let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1231 ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int16, true))
1232 } else {
1233 ScalarValue::Null
1234 }
1235 }
1236 &Type::INT4_ARRAY => {
1237 let data = portal.parameter::<Vec<i32>>(idx, &client_type)?;
1238 if let Some(data) = data {
1239 let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1240 ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int32, true))
1241 } else {
1242 ScalarValue::Null
1243 }
1244 }
1245 &Type::INT8_ARRAY => {
1246 let data = portal.parameter::<Vec<i64>>(idx, &client_type)?;
1247 if let Some(data) = data {
1248 let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1249 ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int64, true))
1250 } else {
1251 ScalarValue::Null
1252 }
1253 }
1254 &Type::VARCHAR_ARRAY => {
1255 let data = portal.parameter::<Vec<String>>(idx, &client_type)?;
1256 if let Some(data) = data {
1257 let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
1258 ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Utf8, true))
1259 } else {
1260 ScalarValue::Null
1261 }
1262 }
1263 &Type::TIMESTAMP_ARRAY => {
1264 let data = portal.parameter::<Vec<NaiveDateTime>>(idx, &client_type)?;
1265 if let Some(data) = data {
1266 if let Some(ConcreteDataType::List(list_type)) = &server_type {
1267 match list_type.item_type() {
1268 ConcreteDataType::Timestamp(unit) => match *unit {
1269 TimestampType::Second(_) => {
1270 let values = data
1271 .into_iter()
1272 .map(|ts| {
1273 ScalarValue::TimestampSecond(
1274 Some(ts.and_utc().timestamp()),
1275 None,
1276 )
1277 })
1278 .collect::<Vec<_>>();
1279 ScalarValue::List(ScalarValue::new_list(
1280 &values,
1281 &ArrowDataType::Timestamp(TimeUnit::Second, None),
1282 true,
1283 ))
1284 }
1285 TimestampType::Millisecond(_) => {
1286 let values = data
1287 .into_iter()
1288 .map(|ts| {
1289 ScalarValue::TimestampMillisecond(
1290 Some(ts.and_utc().timestamp_millis()),
1291 None,
1292 )
1293 })
1294 .collect::<Vec<_>>();
1295 ScalarValue::List(ScalarValue::new_list(
1296 &values,
1297 &ArrowDataType::Timestamp(TimeUnit::Millisecond, None),
1298 true,
1299 ))
1300 }
1301 TimestampType::Microsecond(_) => {
1302 let values = data
1303 .into_iter()
1304 .map(|ts| {
1305 ScalarValue::TimestampMicrosecond(
1306 Some(ts.and_utc().timestamp_micros()),
1307 None,
1308 )
1309 })
1310 .collect::<Vec<_>>();
1311 ScalarValue::List(ScalarValue::new_list(
1312 &values,
1313 &ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
1314 true,
1315 ))
1316 }
1317 TimestampType::Nanosecond(_) => {
1318 let values = data
1319 .into_iter()
1320 .filter_map(|ts| {
1321 ts.and_utc().timestamp_nanos_opt().map(|nanos| {
1322 ScalarValue::TimestampNanosecond(Some(nanos), None)
1323 })
1324 })
1325 .collect::<Vec<_>>();
1326 ScalarValue::List(ScalarValue::new_list(
1327 &values,
1328 &ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1329 true,
1330 ))
1331 }
1332 },
1333 _ => {
1334 return Err(invalid_parameter_error(
1335 "invalid_parameter_type",
1336 Some(format!(
1337 "Expected: {}, found: {}",
1338 list_type.item_type(),
1339 client_type
1340 )),
1341 ));
1342 }
1343 }
1344 } else {
1345 let values = data
1347 .into_iter()
1348 .map(|ts| {
1349 ScalarValue::TimestampMillisecond(
1350 Some(ts.and_utc().timestamp_millis()),
1351 None,
1352 )
1353 })
1354 .collect::<Vec<_>>();
1355 ScalarValue::List(ScalarValue::new_list(
1356 &values,
1357 &ArrowDataType::Timestamp(TimeUnit::Millisecond, None),
1358 true,
1359 ))
1360 }
1361 } else {
1362 ScalarValue::Null
1363 }
1364 }
1365 &Type::TIMESTAMPTZ_ARRAY => {
1366 let data = portal.parameter::<Vec<DateTime<FixedOffset>>>(idx, &client_type)?;
1367 if let Some(data) = data {
1368 if let Some(ConcreteDataType::List(list_type)) = &server_type {
1369 match list_type.item_type() {
1370 ConcreteDataType::Timestamp(unit) => match *unit {
1371 TimestampType::Second(_) => {
1372 let values = data
1373 .into_iter()
1374 .map(|ts| {
1375 ScalarValue::TimestampSecond(Some(ts.timestamp()), None)
1376 })
1377 .collect::<Vec<_>>();
1378 ScalarValue::List(ScalarValue::new_list(
1379 &values,
1380 &ArrowDataType::Timestamp(TimeUnit::Second, None),
1381 true,
1382 ))
1383 }
1384 TimestampType::Millisecond(_) => {
1385 let values = data
1386 .into_iter()
1387 .map(|ts| {
1388 ScalarValue::TimestampMillisecond(
1389 Some(ts.timestamp_millis()),
1390 None,
1391 )
1392 })
1393 .collect::<Vec<_>>();
1394 ScalarValue::List(ScalarValue::new_list(
1395 &values,
1396 &ArrowDataType::Timestamp(TimeUnit::Millisecond, None),
1397 true,
1398 ))
1399 }
1400 TimestampType::Microsecond(_) => {
1401 let values = data
1402 .into_iter()
1403 .map(|ts| {
1404 ScalarValue::TimestampMicrosecond(
1405 Some(ts.timestamp_micros()),
1406 None,
1407 )
1408 })
1409 .collect::<Vec<_>>();
1410 ScalarValue::List(ScalarValue::new_list(
1411 &values,
1412 &ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
1413 true,
1414 ))
1415 }
1416 TimestampType::Nanosecond(_) => {
1417 let values = data
1418 .into_iter()
1419 .filter_map(|ts| {
1420 ts.timestamp_nanos_opt().map(|nanos| {
1421 ScalarValue::TimestampNanosecond(Some(nanos), None)
1422 })
1423 })
1424 .collect::<Vec<_>>();
1425 ScalarValue::List(ScalarValue::new_list(
1426 &values,
1427 &ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1428 true,
1429 ))
1430 }
1431 },
1432 _ => {
1433 return Err(invalid_parameter_error(
1434 "invalid_parameter_type",
1435 Some(format!(
1436 "Expected: {}, found: {}",
1437 list_type.item_type(),
1438 client_type
1439 )),
1440 ));
1441 }
1442 }
1443 } else {
1444 let values = data
1446 .into_iter()
1447 .map(|ts| {
1448 ScalarValue::TimestampMillisecond(Some(ts.timestamp_millis()), None)
1449 })
1450 .collect::<Vec<_>>();
1451 ScalarValue::List(ScalarValue::new_list(
1452 &values,
1453 &ArrowDataType::Timestamp(TimeUnit::Millisecond, None),
1454 true,
1455 ))
1456 }
1457 } else {
1458 ScalarValue::Null
1459 }
1460 }
1461 _ => Err(invalid_parameter_error(
1462 "unsupported_parameter_value",
1463 Some(format!("Found type: {}", client_type)),
1464 ))?,
1465 };
1466
1467 results.push(value);
1468 }
1469
1470 Ok(results)
1471}
1472
1473pub(super) fn param_types_to_pg_types(
1474 param_types: &HashMap<String, Option<ConcreteDataType>>,
1475) -> Result<Vec<Type>> {
1476 let param_count = param_types.len();
1477 let mut types = Vec::with_capacity(param_count);
1478 for i in 0..param_count {
1479 if let Some(Some(param_type)) = param_types.get(&format!("${}", i + 1)) {
1480 let pg_type = type_gt_to_pg(param_type)?;
1481 types.push(pg_type);
1482 } else {
1483 types.push(Type::UNKNOWN);
1484 }
1485 }
1486 Ok(types)
1487}
1488
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{
        Float64Builder, Int64Builder, ListBuilder, StringBuilder, TimestampSecondBuilder,
    };
    use arrow_schema::Field;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{
        BinaryVector, BooleanVector, DateVector, Float32Vector, Float64Vector, Int8Vector,
        Int16Vector, Int32Vector, Int64Vector, IntervalDayTimeVector, IntervalMonthDayNanoVector,
        IntervalYearMonthVector, ListVector, NullVector, StringVector, TimeSecondVector,
        TimestampSecondVector, UInt8Vector, UInt16Vector, UInt32Vector, UInt64Vector, VectorRef,
    };
    use pgwire::api::Type;
    use pgwire::api::results::{FieldFormat, FieldInfo};
    use session::context::QueryContextBuilder;

    use super::*;

    /// Checks that `schema_to_pg` maps each GreptimeDB column type below to the
    /// expected PostgreSQL `FieldInfo` when the whole result uses text format.
    #[test]
    fn test_schema_convert() {
        let column_schemas = vec![
            ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true),
            ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true),
            ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true),
            ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true),
            ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true),
            ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true),
            ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true),
            ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true),
            ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true),
            ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true),
            ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true),
            ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true),
            ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
            ColumnSchema::new(
                "timestamps",
                ConcreteDataType::timestamp_millisecond_datatype(),
                true,
            ),
            ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true),
            ColumnSchema::new("times", ConcreteDataType::time_second_datatype(), true),
            ColumnSchema::new(
                "intervals",
                ConcreteDataType::interval_month_day_nano_datatype(),
                true,
            ),
        ];
        // Expected PostgreSQL fields, in the same order as `column_schemas`.
        let pg_field_info = vec![
            FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
            FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
            FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
            FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
            FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
            FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
            FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
            FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
            FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
            FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
            FieldInfo::new(
                "float32s".into(),
                None,
                None,
                Type::FLOAT4,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "float64s".into(),
                None,
                None,
                Type::FLOAT8,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "binaries".into(),
                None,
                None,
                Type::BYTEA,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "strings".into(),
                None,
                None,
                Type::VARCHAR,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "timestamps".into(),
                None,
                None,
                Type::TIMESTAMP,
                FieldFormat::Text,
            ),
            FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
            FieldInfo::new("times".into(), None, None, Type::TIME, FieldFormat::Text),
            FieldInfo::new(
                "intervals".into(),
                None,
                None,
                Type::INTERVAL,
                FieldFormat::Text,
            ),
        ];
        let schema = Schema::new(column_schemas);
        let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
        assert_eq!(fs, pg_field_info);
    }

    /// Encodes a three-row record batch with `RecordBatchRowIterator`. The batch
    /// covers scalar, temporal, interval and list columns. The test checks that
    /// every row encodes successfully with one field per output column.
    #[test]
    fn test_encode_text_format_data() {
        // NOTE: `schema`, `arrow_schema` and `values` are matched by position;
        // keep all three lists in the same column order.
        let schema = vec![
            FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
            FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
            FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
            FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
            FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
            FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
            FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
            FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
            FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
            FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
            FieldInfo::new(
                "float32s".into(),
                None,
                None,
                Type::FLOAT4,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "float64s".into(),
                None,
                None,
                Type::FLOAT8,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "strings".into(),
                None,
                None,
                Type::VARCHAR,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "binaries".into(),
                None,
                None,
                Type::BYTEA,
                FieldFormat::Text,
            ),
            FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
            FieldInfo::new("times".into(), None, None, Type::TIME, FieldFormat::Text),
            FieldInfo::new(
                "timestamps".into(),
                None,
                None,
                Type::TIMESTAMP,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "interval_year_month".into(),
                None,
                None,
                Type::INTERVAL,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "interval_day_time".into(),
                None,
                None,
                Type::INTERVAL,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "interval_month_day_nano".into(),
                None,
                None,
                Type::INTERVAL,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "int_list".into(),
                None,
                None,
                Type::INT8_ARRAY,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "float_list".into(),
                None,
                None,
                Type::FLOAT8_ARRAY,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "string_list".into(),
                None,
                None,
                Type::VARCHAR_ARRAY,
                FieldFormat::Text,
            ),
            FieldInfo::new(
                "timestamp_list".into(),
                None,
                None,
                Type::TIMESTAMP_ARRAY,
                FieldFormat::Text,
            ),
        ];

        // Arrow schema backing the record batch; field names are irrelevant here,
        // only the data types (and their order) matter.
        let arrow_schema = arrow_schema::Schema::new(vec![
            Field::new("x", DataType::Null, true),
            Field::new("x", DataType::Boolean, true),
            Field::new("x", DataType::UInt8, true),
            Field::new("x", DataType::UInt16, true),
            Field::new("x", DataType::UInt32, true),
            Field::new("x", DataType::UInt64, true),
            Field::new("x", DataType::Int8, true),
            Field::new("x", DataType::Int16, true),
            Field::new("x", DataType::Int32, true),
            Field::new("x", DataType::Int64, true),
            Field::new("x", DataType::Float32, true),
            Field::new("x", DataType::Float64, true),
            Field::new("x", DataType::Utf8, true),
            Field::new("x", DataType::Binary, true),
            Field::new("x", DataType::Date32, true),
            Field::new("x", DataType::Time32(TimeUnit::Second), true),
            Field::new("x", DataType::Timestamp(TimeUnit::Second, None), true),
            Field::new("x", DataType::Interval(IntervalUnit::YearMonth), true),
            Field::new("x", DataType::Interval(IntervalUnit::DayTime), true),
            Field::new("x", DataType::Interval(IntervalUnit::MonthDayNano), true),
            Field::new(
                "x",
                DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
                true,
            ),
            Field::new(
                "x",
                DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
                true,
            ),
            Field::new(
                "x",
                DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
                true,
            ),
            Field::new(
                "x",
                DataType::List(Arc::new(Field::new(
                    "item",
                    DataType::Timestamp(TimeUnit::Second, None),
                    true,
                ))),
                true,
            ),
        ]);

        // Each list column has three rows: a list with an inner null, a null
        // list, and another list with an inner null.
        let mut builder = ListBuilder::new(Int64Builder::new());
        builder.append_value([Some(1i64), None, Some(2)]);
        builder.append_null();
        builder.append_value([Some(-1i64), None, Some(-2)]);
        let i64_list_array = builder.finish();

        let mut builder = ListBuilder::new(Float64Builder::new());
        builder.append_value([Some(1.0f64), None, Some(2.0)]);
        builder.append_null();
        builder.append_value([Some(-1.0f64), None, Some(-2.0)]);
        let f64_list_array = builder.finish();

        let mut builder = ListBuilder::new(StringBuilder::new());
        builder.append_value([Some("a"), None, Some("b")]);
        builder.append_null();
        builder.append_value([Some("c"), None, Some("d")]);
        let string_list_array = builder.finish();

        let mut builder = ListBuilder::new(TimestampSecondBuilder::new());
        builder.append_value([Some(1i64), None, Some(2)]);
        builder.append_null();
        builder.append_value([Some(3i64), None, Some(4)]);
        let timestamp_list_array = builder.finish();

        // Column data; numeric columns include their type's extreme values.
        let values = vec![
            Arc::new(NullVector::new(3)) as VectorRef,
            Arc::new(BooleanVector::from(vec![Some(true), Some(false), None])),
            Arc::new(UInt8Vector::from(vec![Some(u8::MAX), Some(u8::MIN), None])),
            Arc::new(UInt16Vector::from(vec![
                Some(u16::MAX),
                Some(u16::MIN),
                None,
            ])),
            Arc::new(UInt32Vector::from(vec![
                Some(u32::MAX),
                Some(u32::MIN),
                None,
            ])),
            Arc::new(UInt64Vector::from(vec![
                Some(u64::MAX),
                Some(u64::MIN),
                None,
            ])),
            Arc::new(Int8Vector::from(vec![Some(i8::MAX), Some(i8::MIN), None])),
            Arc::new(Int16Vector::from(vec![
                Some(i16::MAX),
                Some(i16::MIN),
                None,
            ])),
            Arc::new(Int32Vector::from(vec![
                Some(i32::MAX),
                Some(i32::MIN),
                None,
            ])),
            Arc::new(Int64Vector::from(vec![
                Some(i64::MAX),
                Some(i64::MIN),
                None,
            ])),
            Arc::new(Float32Vector::from(vec![
                None,
                Some(f32::MAX),
                Some(f32::MIN),
            ])),
            Arc::new(Float64Vector::from(vec![
                None,
                Some(f64::MAX),
                Some(f64::MIN),
            ])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
            ])),
            Arc::new(BinaryVector::from(vec![
                None,
                Some("hello".as_bytes().to_vec()),
                Some("world".as_bytes().to_vec()),
            ])),
            Arc::new(DateVector::from(vec![Some(1001), None, Some(1)])),
            Arc::new(TimeSecondVector::from(vec![Some(1001), None, Some(1)])),
            Arc::new(TimestampSecondVector::from(vec![
                Some(1000001),
                None,
                Some(1),
            ])),
            Arc::new(IntervalYearMonthVector::from(vec![Some(1), None, Some(2)])),
            Arc::new(IntervalDayTimeVector::from(vec![
                Some(arrow::datatypes::IntervalDayTime::new(1, 1)),
                None,
                Some(arrow::datatypes::IntervalDayTime::new(2, 2)),
            ])),
            Arc::new(IntervalMonthDayNanoVector::from(vec![
                Some(arrow::datatypes::IntervalMonthDayNano::new(1, 1, 10)),
                None,
                Some(arrow::datatypes::IntervalMonthDayNano::new(2, 2, 20)),
            ])),
            Arc::new(ListVector::from(i64_list_array)),
            Arc::new(ListVector::from(f64_list_array)),
            Arc::new(ListVector::from(string_list_array)),
            Arc::new(ListVector::from(timestamp_list_array)),
        ];
        let record_batch =
            RecordBatch::new(Arc::new(arrow_schema.try_into().unwrap()), values).unwrap();

        let query_context = QueryContextBuilder::default()
            .configuration_parameter(Default::default())
            .build()
            .into();
        let schema = Arc::new(schema);

        // Surface the actual encoding error instead of silently dropping the row
        // (which would only show up as a confusing length mismatch below).
        let rows = RecordBatchRowIterator::new(query_context, schema.clone(), record_batch)
            .map(|row| row.expect("failed to encode row"))
            .collect::<Vec<_>>();
        assert_eq!(rows.len(), 3);
        for row in rows {
            assert_eq!(row.field_count, schema.len() as i16);
        }
    }

    /// Checks that `invalid_parameter_error` yields a `UserError` with severity
    /// `ERROR`, SQLSTATE `22023` and the supplied message.
    #[test]
    fn test_invalid_parameter() {
        let msg = "invalid_parameter_count";

        let error = invalid_parameter_error(msg, None);
        if let PgWireError::UserError(value) = error {
            assert_eq!("ERROR", value.severity);
            assert_eq!("22023", value.code);
            assert_eq!(msg, value.message);
        } else {
            panic!("test_invalid_parameter failed");
        }
    }
}