1use std::fmt::Debug;
16use std::pin::Pin;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use common_query::{Output, OutputData};
21use common_recordbatch::RecordBatch;
22use common_recordbatch::error::Result as RecordBatchResult;
23use common_telemetry::{debug, tracing};
24use datafusion::sql::sqlparser::ast::{CopyOption, CopyTarget, Statement as SqlParserStatement};
25use datafusion_common::ParamValues;
26use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
27use datatypes::prelude::ConcreteDataType;
28use datatypes::schema::SchemaRef;
29use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
30use pgwire::api::portal::{Format, Portal};
31use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
32use pgwire::api::results::{
33 CopyCsvOptions, CopyEncoder, CopyResponse, CopyTextOptions, DataRowEncoder,
34 DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
35};
36use pgwire::api::stmt::{QueryParser, StoredStatement};
37use pgwire::api::{ClientInfo, ErrorHandler, Type};
38use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
39use pgwire::messages::PgWireBackendMessage;
40use pgwire::messages::copy::CopyData;
41use pgwire::messages::data::DataRow;
42use query::planner::DfLogicalPlanner;
43use query::query_engine::DescribeResult;
44use session::Session;
45use session::context::QueryContextRef;
46use snafu::ResultExt;
47use sql::dialect::PostgreSqlDialect;
48use sql::parser::{ParseOptions, ParserContext};
49use sql::statements::statement::Statement;
50
51use crate::SqlPlan;
52use crate::error::{DataFusionSnafu, InferParameterTypesSnafu, Result};
53use crate::postgres::types::*;
54use crate::postgres::utils::convert_err;
55use crate::postgres::{PostgresServerHandlerInner, fixtures};
56use crate::query_handler::sql::ServerSqlQueryHandlerRef;
57
58#[async_trait]
59impl SimpleQueryHandler for PostgresServerHandlerInner {
60 #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
61 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
62 where
63 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
64 C::Error: Debug,
65 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
66 {
67 let query_ctx = self.session.new_query_context();
68 let db = query_ctx.get_db_string();
69 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
70 .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
71 .start_timer();
72
73 if query.is_empty() {
74 return Ok(vec![Response::EmptyQuery]);
76 }
77
78 let parsed_query = self.query_parser.compatibility_parser.parse(query);
79
80 let query = if let Ok(statements) = &parsed_query {
81 statements
82 .iter()
83 .map(|s| s.to_string())
84 .collect::<Vec<_>>()
85 .join(";")
86 } else {
87 query.to_string()
88 };
89
90 if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
91 send_warning_opt(client, query_ctx).await?;
92 Ok(resps)
93 } else {
94 let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
95
96 let mut results = Vec::with_capacity(outputs.len());
97
98 let statements = parsed_query.ok();
99 for (idx, output) in outputs.into_iter().enumerate() {
100 let copy_format = statements
101 .as_ref()
102 .and_then(|stmts| stmts.get(idx))
103 .and_then(check_copy_to_stdout);
104 let resp = if let Some(format) = ©_format {
105 output_to_copy_response(query_ctx.clone(), output, format)?
106 } else {
107 output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?
108 };
109 results.push(resp);
110 }
111
112 send_warning_opt(client, query_ctx).await?;
113 Ok(results)
114 }
115 }
116}
117
118async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
119where
120 C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
121 C::Error: Debug,
122 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
123{
124 if let Some(warning) = query_context.warning() {
125 client
126 .feed(PgWireBackendMessage::NoticeResponse(
127 ErrorInfo::new(
128 PgErrorSeverity::Warning.to_string(),
129 PgErrorCode::Ec01000.code(),
130 warning.clone(),
131 )
132 .into(),
133 ))
134 .await?;
135 }
136
137 Ok(())
138}
139
140pub(crate) fn output_to_query_response(
141 query_ctx: QueryContextRef,
142 output: Result<Output>,
143 field_format: &Format,
144) -> PgWireResult<Response> {
145 match output {
146 Ok(o) => match o.data {
147 OutputData::AffectedRows(rows) => {
148 Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
149 }
150 OutputData::Stream(record_stream) => {
151 let schema = record_stream.schema();
152 recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
153 }
154 OutputData::RecordBatches(recordbatches) => {
155 let schema = recordbatches.schema();
156 recordbatches_to_query_response(
157 query_ctx,
158 recordbatches.as_stream(),
159 schema,
160 field_format,
161 )
162 }
163 },
164 Err(e) => Err(convert_err(e)),
165 }
166}
167
168type RowStream<T> = Pin<Box<dyn Stream<Item = PgWireResult<T>> + Send + Unpin>>;
169
170fn recordbatches_to_query_response<S>(
171 query_ctx: QueryContextRef,
172 recordbatches_stream: S,
173 schema: SchemaRef,
174 field_format: &Format,
175) -> PgWireResult<Response>
176where
177 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
178{
179 let format_options = format_options_from_query_ctx(&query_ctx);
180 let pg_schema = Arc::new(
181 schema_to_pg(schema.as_ref(), field_format, Some(format_options)).map_err(convert_err)?,
182 );
183
184 let encoder = DataRowEncoder::new(pg_schema.clone());
185 let row_stream = RecordBatchRowStream::new(
186 query_ctx.clone(),
187 pg_schema.clone(),
188 schema.clone(),
189 recordbatches_stream,
190 encoder,
191 );
192
193 let data_row_stream: RowStream<DataRow> = Box::pin(
194 row_stream
195 .map(move |result| match result {
196 Ok(rows) => Box::pin(stream::iter(rows.into_iter().map(Ok))) as RowStream<DataRow>,
197 Err(e) => Box::pin(stream::once(future::ready(Err(e)))) as RowStream<DataRow>,
198 })
199 .flatten(),
200 );
201
202 Ok(Response::Query(QueryResponse::new(
203 pg_schema,
204 data_row_stream,
205 )))
206}
207
208pub(crate) fn output_to_copy_response(
209 query_ctx: QueryContextRef,
210 output: Result<Output>,
211 format: &str,
212) -> PgWireResult<Response> {
213 match output {
214 Ok(o) => match o.data {
215 OutputData::AffectedRows(_) => Err(PgWireError::UserError(Box::new(ErrorInfo::new(
216 "ERROR".to_string(),
217 "42601".to_string(),
218 "COPY cannot be used with non-query statements".to_string(),
219 )))),
220 OutputData::Stream(record_stream) => {
221 let schema = record_stream.schema();
222 recordbatches_to_copy_response(query_ctx, record_stream, schema, format)
223 }
224 OutputData::RecordBatches(recordbatches) => {
225 let schema = recordbatches.schema();
226 recordbatches_to_copy_response(query_ctx, recordbatches.as_stream(), schema, format)
227 }
228 },
229 Err(e) => Err(convert_err(e)),
230 }
231}
232
233fn recordbatches_to_copy_response<S>(
234 query_ctx: QueryContextRef,
235 recordbatches_stream: S,
236 schema: SchemaRef,
237 format: &str,
238) -> PgWireResult<Response>
239where
240 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
241{
242 let format_options = format_options_from_query_ctx(&query_ctx);
243 let pg_fields = schema_to_pg(schema.as_ref(), &Format::UnifiedText, Some(format_options))
244 .map_err(convert_err)?;
245
246 let copy_format = match format.to_lowercase().as_str() {
247 "binary" => 1,
248 _ => 0,
249 };
250
251 let pg_schema = Arc::new(pg_fields);
252 let num_columns = pg_schema.len();
253
254 let copy_encoder = match format.to_lowercase().as_str() {
255 "csv" => CopyEncoder::new_csv(pg_schema.clone(), CopyCsvOptions::default()),
256 "binary" => CopyEncoder::new_binary(pg_schema.clone()),
257 _ => CopyEncoder::new_text(pg_schema.clone(), CopyTextOptions::default()),
258 };
259
260 let row_stream = RecordBatchRowStream::new(
261 query_ctx.clone(),
262 pg_schema.clone(),
263 schema.clone(),
264 recordbatches_stream,
265 copy_encoder,
266 );
267
268 let copy_stream: RowStream<CopyData> = Box::pin(
269 row_stream
270 .map(move |result| match result {
271 Ok(rows) => Box::pin(stream::iter(rows.into_iter().map(Ok))) as RowStream<CopyData>,
272 Err(e) => Box::pin(stream::once(future::ready(Err(e)))) as RowStream<CopyData>,
273 })
274 .flatten(),
275 );
276
277 Ok(Response::CopyOut(CopyResponse::new(
278 copy_format,
279 num_columns,
280 copy_stream,
281 )))
282}
283
284pub struct DefaultQueryParser {
285 query_handler: ServerSqlQueryHandlerRef,
286 session: Arc<Session>,
287 compatibility_parser: PostgresCompatibilityParser,
288}
289
290impl DefaultQueryParser {
291 pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
292 DefaultQueryParser {
293 query_handler,
294 session,
295 compatibility_parser: PostgresCompatibilityParser::new(),
296 }
297 }
298}
299
300#[derive(Clone, Debug)]
302pub struct PgSqlPlan {
303 plan: SqlPlan,
304 copy_to_stdout_format: Option<String>,
305}
306
307#[async_trait]
308impl QueryParser for DefaultQueryParser {
309 type Statement = PgSqlPlan;
310
311 async fn parse_sql<C>(
312 &self,
313 _client: &C,
314 sql: &str,
315 _types: &[Option<Type>],
316 ) -> PgWireResult<Self::Statement> {
317 crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
318 let query_ctx = self.session.new_query_context();
319
320 if sql.is_empty() || fixtures::matches(sql) {
322 return Ok(PgSqlPlan {
323 plan: SqlPlan {
324 query: sql.to_owned(),
325 statement: None,
326 plan: None,
327 schema: None,
328 },
329 copy_to_stdout_format: None,
330 });
331 }
332
333 let parsed_statements = self.compatibility_parser.parse(sql);
334 let (sql, copy_to_stdout_format) = if let Ok(mut statements) = parsed_statements {
335 let first_stmt = statements.remove(0);
336 let format = check_copy_to_stdout(&first_stmt);
337 (first_stmt.to_string(), format)
338 } else {
339 (sql.to_string(), None)
342 };
343
344 let mut stmts = ParserContext::create_with_dialect(
345 &sql,
346 &PostgreSqlDialect {},
347 ParseOptions::default(),
348 )
349 .map_err(convert_err)?;
350 if stmts.len() != 1 {
351 Err(PgWireError::UserError(Box::new(ErrorInfo::from(
352 PgErrorCode::Ec42P14,
353 ))))
354 } else {
355 let stmt = stmts.remove(0);
356
357 let describe_result = self
358 .query_handler
359 .do_describe(stmt.clone(), query_ctx)
360 .await
361 .map_err(convert_err)?;
362
363 let (plan, schema) = if let Some(DescribeResult {
364 logical_plan,
365 schema,
366 }) = describe_result
367 {
368 (Some(logical_plan), Some(schema))
369 } else {
370 (None, None)
371 };
372
373 Ok(PgSqlPlan {
374 plan: SqlPlan {
375 query: sql.clone(),
376 statement: Some(stmt),
377 plan,
378 schema,
379 },
380 copy_to_stdout_format,
381 })
382 }
383 }
384
385 fn get_parameter_types(&self, _stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
386 Err(PgWireError::ApiError(
389 "get_parameter_types is not expected to be called".into(),
390 ))
391 }
392
393 fn get_result_schema(
394 &self,
395 _stmt: &Self::Statement,
396 _column_format: Option<&Format>,
397 ) -> PgWireResult<Vec<FieldInfo>> {
398 Err(PgWireError::ApiError(
401 "get_result_schema is not expected to be called".into(),
402 ))
403 }
404}
405
406#[async_trait]
407impl ExtendedQueryHandler for PostgresServerHandlerInner {
408 type Statement = PgSqlPlan;
409 type QueryParser = DefaultQueryParser;
410
411 fn query_parser(&self) -> Arc<Self::QueryParser> {
412 self.query_parser.clone()
413 }
414
415 async fn do_query<C>(
416 &self,
417 client: &mut C,
418 portal: &Portal<Self::Statement>,
419 _max_rows: usize,
420 ) -> PgWireResult<Response>
421 where
422 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
423 C::Error: Debug,
424 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
425 {
426 let query_ctx = self.session.new_query_context();
427 let db = query_ctx.get_db_string();
428 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
429 .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
430 .start_timer();
431
432 let pg_sql_plan = &portal.statement.statement;
433 let sql_plan = &pg_sql_plan.plan;
434
435 if sql_plan.query.is_empty() {
436 return Ok(Response::EmptyQuery);
438 }
439
440 if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
441 send_warning_opt(client, query_ctx).await?;
442 return Ok(resps.remove(0));
444 }
445
446 let output = if let Some(plan) = &sql_plan.plan {
447 let values = parameters_to_scalar_values(plan, portal)?;
448 let plan = plan
449 .clone()
450 .replace_params_with_values(&ParamValues::List(
451 values.into_iter().map(Into::into).collect(),
452 ))
453 .context(DataFusionSnafu)
454 .map_err(convert_err)?;
455 self.query_handler
456 .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
457 .await
458 } else {
459 let mut sql = sql_plan.query.clone();
463 for i in 0..portal.parameter_len() {
464 sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?);
465 }
466
467 self.query_handler
468 .do_query(&sql, query_ctx.clone())
469 .await
470 .remove(0)
471 };
472
473 send_warning_opt(client, query_ctx.clone()).await?;
474
475 if let Some(format) = &pg_sql_plan.copy_to_stdout_format {
476 output_to_copy_response(query_ctx, output, format)
477 } else {
478 output_to_query_response(query_ctx, output, &portal.result_column_format)
479 }
480 }
481
482 async fn do_describe_statement<C>(
483 &self,
484 _client: &mut C,
485 stmt: &StoredStatement<Self::Statement>,
486 ) -> PgWireResult<DescribeStatementResponse>
487 where
488 C: ClientInfo + Unpin + Send + Sync,
489 {
490 let sql_plan = &stmt.statement.plan;
491 let provided_param_types = &stmt.parameter_types;
493 let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
494 let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
495 .context(InferParameterTypesSnafu)
496 .map_err(convert_err)?
497 .into_iter()
498 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
499 .collect();
500
501 let types = param_types_to_pg_types(¶m_types).map_err(convert_err)?;
502
503 Some(types)
504 } else {
505 None
506 };
507
508 let param_count = if provided_param_types.is_empty() {
509 server_inferenced_types
510 .as_ref()
511 .map(|types| types.len())
512 .unwrap_or(0)
513 } else {
514 provided_param_types.len()
515 };
516
517 let param_types = (0..param_count)
518 .map(|i| {
519 let client_type = provided_param_types.get(i);
520 match client_type {
522 Some(Some(client_type)) => client_type.clone(),
523 _ => server_inferenced_types
524 .as_ref()
525 .and_then(|types| types.get(i).cloned())
526 .unwrap_or(Type::UNKNOWN),
527 }
528 })
529 .collect::<Vec<_>>();
530
531 if let Some(schema) = &sql_plan.schema {
532 schema_to_pg(schema, &Format::UnifiedBinary, None)
533 .map(|fields| DescribeStatementResponse::new(param_types, fields))
534 .map_err(convert_err)
535 } else {
536 if let Some(mut resp) =
537 fixtures::process(&sql_plan.query, self.session.new_query_context())
538 && let Response::Query(query_response) = resp.remove(0)
539 {
540 return Ok(DescribeStatementResponse::new(
541 param_types,
542 (*query_response.row_schema()).clone(),
543 ));
544 }
545
546 Ok(DescribeStatementResponse::new(param_types, vec![]))
547 }
548 }
549
550 async fn do_describe_portal<C>(
551 &self,
552 _client: &mut C,
553 portal: &Portal<Self::Statement>,
554 ) -> PgWireResult<DescribePortalResponse>
555 where
556 C: ClientInfo + Unpin + Send + Sync,
557 {
558 let sql_plan = &portal.statement.statement.plan;
559 let format = &portal.result_column_format;
560
561 match sql_plan.statement.as_ref() {
562 Some(Statement::Query(_)) => {
563 if let Some(schema) = &sql_plan.schema {
565 schema_to_pg(schema, format, None)
566 .map(DescribePortalResponse::new)
567 .map_err(convert_err)
568 } else {
569 Ok(DescribePortalResponse::new(vec![]))
571 }
572 }
573 Some(Statement::ShowCreateDatabase(_))
576 | Some(Statement::ShowCreateTable(_))
577 | Some(Statement::ShowCreateFlow(_))
578 | Some(Statement::ShowCreateView(_)) => Ok(DescribePortalResponse::new(vec![
579 FieldInfo::new(
580 "name".to_string(),
581 None,
582 None,
583 Type::TEXT,
584 format.format_for(0),
585 ),
586 FieldInfo::new(
587 "create_statement".to_string(),
588 None,
589 None,
590 Type::TEXT,
591 format.format_for(1),
592 ),
593 ])),
594 Some(Statement::ShowTables(_))
596 | Some(Statement::ShowFlows(_))
597 | Some(Statement::ShowViews(_)) => {
598 Ok(DescribePortalResponse::new(vec![FieldInfo::new(
599 "name".to_string(),
600 None,
601 None,
602 Type::TEXT,
603 format.format_for(0),
604 )]))
605 }
606 _ => {
609 if let Some(mut resp) =
611 fixtures::process(&sql_plan.query, self.session.new_query_context())
612 && let Response::Query(query_response) = resp.remove(0)
613 {
614 Ok(DescribePortalResponse::new(
615 (*query_response.row_schema()).clone(),
616 ))
617 } else {
618 Ok(DescribePortalResponse::new(vec![]))
620 }
621 }
622 }
623 }
624}
625
626impl ErrorHandler for PostgresServerHandlerInner {
627 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
628 where
629 C: ClientInfo,
630 {
631 debug!("Postgres interface error {}", error)
632 }
633}
634
635fn check_copy_to_stdout(statement: &SqlParserStatement) -> Option<String> {
636 if let SqlParserStatement::Copy {
637 target, options, ..
638 } = statement
639 && matches!(target, CopyTarget::Stdout)
640 {
641 for opt in options {
642 if let CopyOption::Format(format_ident) = opt {
643 return Some(format_ident.value.to_lowercase());
644 }
645 }
646 return Some("txt".to_string());
647 }
648
649 None
650}
651
#[cfg(test)]
mod tests {
    use datafusion_pg_catalog::sql::PostgresCompatibilityParser;

    use super::*;

    /// Parses `sql` with the compatibility parser and returns its first statement.
    fn parse_copy_statement(sql: &str) -> SqlParserStatement {
        PostgresCompatibilityParser::new()
            .parse(sql)
            .unwrap()
            .into_iter()
            .next()
            .unwrap()
    }

    /// Asserts that `check_copy_to_stdout` yields `expected` for the first statement of `sql`.
    fn assert_copy_format(sql: &str, expected: Option<&str>) {
        let statement = parse_copy_statement(sql);
        assert_eq!(
            check_copy_to_stdout(&statement),
            expected.map(str::to_owned)
        );
    }

    #[test]
    fn test_check_copy_out_with_csv_format() {
        assert_copy_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT CSV)", Some("csv"));
    }

    #[test]
    fn test_check_copy_out_with_txt_format() {
        assert_copy_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT TXT)", Some("txt"));
    }

    #[test]
    fn test_check_copy_out_with_binary_format() {
        assert_copy_format(
            "COPY (SELECT 1) TO STDOUT WITH (FORMAT BINARY)",
            Some("binary"),
        );
    }

    #[test]
    fn test_check_copy_out_without_format() {
        assert_copy_format("COPY (SELECT 1) TO STDOUT", Some("txt"));
    }

    #[test]
    fn test_check_copy_out_to_file() {
        assert_copy_format(
            "COPY (SELECT 1) TO '/path/to/file.csv' WITH (FORMAT CSV)",
            None,
        );
    }

    #[test]
    fn test_check_copy_out_case_insensitive() {
        assert_copy_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT csv)", Some("csv"));
        assert_copy_format(
            "COPY (SELECT 1) TO STDOUT WITH (FORMAT binary)",
            Some("binary"),
        );
    }

    #[test]
    fn test_check_copy_out_with_multiple_options() {
        assert_copy_format(
            "COPY (SELECT 1) TO STDOUT WITH (FORMAT csv, DELIMITER ',', HEADER)",
            Some("csv"),
        );
        assert_copy_format(
            "COPY (SELECT 1) TO STDOUT WITH (DELIMITER ',', HEADER, FORMAT binary)",
            Some("binary"),
        );
    }
}