1use std::fmt::Debug;
16use std::pin::Pin;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use common_query::{Output, OutputData};
21use common_recordbatch::RecordBatch;
22use common_recordbatch::error::Result as RecordBatchResult;
23use common_telemetry::{debug, info, tracing};
24use datafusion::sql::sqlparser::ast::{CopyOption, CopyTarget, Statement as SqlParserStatement};
25use datafusion_common::ParamValues;
26use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
27use datatypes::prelude::ConcreteDataType;
28use datatypes::schema::SchemaRef;
29use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
30use pgwire::api::portal::{Format, Portal};
31use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
32use pgwire::api::results::{
33 CopyCsvOptions, CopyEncoder, CopyResponse, CopyTextOptions, DataRowEncoder,
34 DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
35};
36use pgwire::api::stmt::{QueryParser, StoredStatement};
37use pgwire::api::{ClientInfo, ErrorHandler, Type};
38use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
39use pgwire::messages::PgWireBackendMessage;
40use pgwire::messages::copy::CopyData;
41use pgwire::messages::data::DataRow;
42use query::planner::DfLogicalPlanner;
43use query::query_engine::DescribeResult;
44use session::Session;
45use session::context::QueryContextRef;
46use snafu::ResultExt;
47use sql::dialect::PostgreSqlDialect;
48use sql::parser::{ParseOptions, ParserContext};
49use sql::statements::statement::Statement;
50
51use crate::SqlPlan;
52use crate::error::{DataFusionSnafu, InferParameterTypesSnafu, Result};
53use crate::postgres::types::*;
54use crate::postgres::utils::convert_err;
55use crate::postgres::{PostgresServerHandlerInner, fixtures};
56use crate::query_handler::sql::ServerSqlQueryHandlerRef;
57
58#[async_trait]
59impl SimpleQueryHandler for PostgresServerHandlerInner {
60 #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
61 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
62 where
63 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
64 C::Error: Debug,
65 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
66 {
67 let query_ctx = self.session.new_query_context();
68 let db = query_ctx.get_db_string();
69 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
70 .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
71 .start_timer();
72
73 if query.is_empty() {
74 return Ok(vec![Response::EmptyQuery]);
76 }
77
78 let parsed_query = self.query_parser.compatibility_parser.parse(query);
79
80 let query = if let Ok(statements) = &parsed_query {
81 statements
82 .iter()
83 .map(|s| s.to_string())
84 .collect::<Vec<_>>()
85 .join(";")
86 } else {
87 query.to_string()
88 };
89
90 if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
91 send_warning_opt(client, query_ctx).await?;
92 Ok(resps)
93 } else {
94 let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
95
96 let mut results = Vec::with_capacity(outputs.len());
97
98 let statements = parsed_query.ok();
99 for (idx, output) in outputs.into_iter().enumerate() {
100 let copy_format = statements
101 .as_ref()
102 .and_then(|stmts| stmts.get(idx))
103 .and_then(check_copy_to_stdout);
104 let resp = if let Some(format) = ©_format {
105 output_to_copy_response(query_ctx.clone(), output, format)?
106 } else {
107 output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?
108 };
109 results.push(resp);
110 }
111
112 send_warning_opt(client, query_ctx).await?;
113 Ok(results)
114 }
115 }
116}
117
118async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
119where
120 C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
121 C::Error: Debug,
122 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
123{
124 if let Some(warning) = query_context.warning() {
125 client
126 .feed(PgWireBackendMessage::NoticeResponse(
127 ErrorInfo::new(
128 PgErrorSeverity::Warning.to_string(),
129 PgErrorCode::Ec01000.code(),
130 warning.clone(),
131 )
132 .into(),
133 ))
134 .await?;
135 }
136
137 Ok(())
138}
139
140pub(crate) fn output_to_query_response(
141 query_ctx: QueryContextRef,
142 output: Result<Output>,
143 field_format: &Format,
144) -> PgWireResult<Response> {
145 match output {
146 Ok(o) => match o.data {
147 OutputData::AffectedRows(rows) => {
148 Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
149 }
150 OutputData::Stream(record_stream) => {
151 let schema = record_stream.schema();
152 recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
153 }
154 OutputData::RecordBatches(recordbatches) => {
155 let schema = recordbatches.schema();
156 recordbatches_to_query_response(
157 query_ctx,
158 recordbatches.as_stream(),
159 schema,
160 field_format,
161 )
162 }
163 },
164 Err(e) => Err(convert_err(e)),
165 }
166}
167
/// Boxed, pinned stream of encoded rows of type `T`, each wrapped in a
/// [`PgWireResult`]. Used in this file for both query data rows and copy-out
/// data.
type RowStream<T> = Pin<Box<dyn Stream<Item = PgWireResult<T>> + Send + Unpin>>;
169
170fn recordbatches_to_query_response<S>(
171 query_ctx: QueryContextRef,
172 recordbatches_stream: S,
173 schema: SchemaRef,
174 field_format: &Format,
175) -> PgWireResult<Response>
176where
177 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
178{
179 let format_options = format_options_from_query_ctx(&query_ctx);
180 let pg_schema = Arc::new(
181 schema_to_pg(schema.as_ref(), field_format, Some(format_options)).map_err(convert_err)?,
182 );
183
184 let encoder = DataRowEncoder::new(pg_schema.clone());
185 let row_stream = RecordBatchRowStream::new(
186 query_ctx.clone(),
187 pg_schema.clone(),
188 schema.clone(),
189 recordbatches_stream,
190 encoder,
191 );
192
193 let data_row_stream: RowStream<DataRow> = Box::pin(
194 row_stream
195 .map(move |result| match result {
196 Ok(rows) => Box::pin(stream::iter(rows.into_iter().map(Ok))) as RowStream<DataRow>,
197 Err(e) => Box::pin(stream::once(future::ready(Err(e)))) as RowStream<DataRow>,
198 })
199 .flatten(),
200 );
201
202 Ok(Response::Query(QueryResponse::new(
203 pg_schema,
204 data_row_stream,
205 )))
206}
207
208pub(crate) fn output_to_copy_response(
209 query_ctx: QueryContextRef,
210 output: Result<Output>,
211 format: &str,
212) -> PgWireResult<Response> {
213 match output {
214 Ok(o) => match o.data {
215 OutputData::AffectedRows(_) => Err(PgWireError::UserError(Box::new(ErrorInfo::new(
216 "ERROR".to_string(),
217 "42601".to_string(),
218 "COPY cannot be used with non-query statements".to_string(),
219 )))),
220 OutputData::Stream(record_stream) => {
221 let schema = record_stream.schema();
222 recordbatches_to_copy_response(query_ctx, record_stream, schema, format)
223 }
224 OutputData::RecordBatches(recordbatches) => {
225 let schema = recordbatches.schema();
226 recordbatches_to_copy_response(query_ctx, recordbatches.as_stream(), schema, format)
227 }
228 },
229 Err(e) => Err(convert_err(e)),
230 }
231}
232
233fn recordbatches_to_copy_response<S>(
234 query_ctx: QueryContextRef,
235 recordbatches_stream: S,
236 schema: SchemaRef,
237 format: &str,
238) -> PgWireResult<Response>
239where
240 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
241{
242 let format_options = format_options_from_query_ctx(&query_ctx);
243 let pg_fields = schema_to_pg(schema.as_ref(), &Format::UnifiedText, Some(format_options))
244 .map_err(convert_err)?;
245
246 let copy_format = match format.to_lowercase().as_str() {
247 "binary" => 1,
248 _ => 0,
249 };
250
251 let pg_schema = Arc::new(pg_fields);
252 let num_columns = pg_schema.len();
253
254 let copy_encoder = match format.to_lowercase().as_str() {
255 "csv" => CopyEncoder::new_csv(pg_schema.clone(), CopyCsvOptions::default()),
256 "binary" => CopyEncoder::new_binary(pg_schema.clone()),
257 _ => CopyEncoder::new_text(pg_schema.clone(), CopyTextOptions::default()),
258 };
259
260 let row_stream = RecordBatchRowStream::new(
261 query_ctx.clone(),
262 pg_schema.clone(),
263 schema.clone(),
264 recordbatches_stream,
265 copy_encoder,
266 );
267
268 let copy_stream: RowStream<CopyData> = Box::pin(
269 row_stream
270 .map(move |result| match result {
271 Ok(rows) => Box::pin(stream::iter(rows.into_iter().map(Ok))) as RowStream<CopyData>,
272 Err(e) => Box::pin(stream::once(future::ready(Err(e)))) as RowStream<CopyData>,
273 })
274 .flatten(),
275 );
276
277 Ok(Response::CopyOut(CopyResponse::new(
278 copy_format,
279 num_columns,
280 copy_stream,
281 )))
282}
283
284pub struct DefaultQueryParser {
285 query_handler: ServerSqlQueryHandlerRef,
286 session: Arc<Session>,
287 compatibility_parser: PostgresCompatibilityParser,
288}
289
290impl DefaultQueryParser {
291 pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
292 DefaultQueryParser {
293 query_handler,
294 session,
295 compatibility_parser: PostgresCompatibilityParser::new(),
296 }
297 }
298}
299
/// A prepared statement as stored by the extended query protocol handlers.
#[derive(Clone, Debug)]
pub struct PgSqlPlan {
    /// The parsed statement together with its optional logical plan and schema.
    pub(crate) plan: SqlPlan,
    /// Lowercased copy format when the statement is `COPY ... TO STDOUT`
    /// (`"txt"` when no format option was given); `None` for any other statement.
    pub(crate) copy_to_stdout_format: Option<String>,
}
306
307#[async_trait]
308impl QueryParser for DefaultQueryParser {
309 type Statement = PgSqlPlan;
310
311 async fn parse_sql<C>(
312 &self,
313 _client: &C,
314 sql: &str,
315 _types: &[Option<Type>],
316 ) -> PgWireResult<Self::Statement> {
317 crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
318 let query_ctx = self.session.new_query_context();
319
320 if sql.is_empty() || fixtures::matches(sql) {
322 return Ok(PgSqlPlan {
323 plan: SqlPlan {
324 query: sql.to_owned(),
325 statement: None,
326 plan: None,
327 schema: None,
328 },
329 copy_to_stdout_format: None,
330 });
331 }
332
333 let parsed_statements = self.compatibility_parser.parse(sql);
334 let (sql, copy_to_stdout_format) = if let Ok(mut statements) = parsed_statements {
335 let first_stmt = statements.remove(0);
336 let format = check_copy_to_stdout(&first_stmt);
337 (first_stmt.to_string(), format)
338 } else {
339 (sql.to_string(), None)
342 };
343
344 let mut stmts = ParserContext::create_with_dialect(
345 &sql,
346 &PostgreSqlDialect {},
347 ParseOptions::default(),
348 )
349 .map_err(convert_err)?;
350 if stmts.len() != 1 {
351 Err(PgWireError::UserError(Box::new(ErrorInfo::from(
352 PgErrorCode::Ec42P14,
353 ))))
354 } else {
355 let stmt = stmts.remove(0);
356
357 let describe_result = self
358 .query_handler
359 .do_describe(stmt.clone(), query_ctx)
360 .await
361 .map_err(convert_err)?;
362
363 let (plan, schema) = if let Some(DescribeResult {
364 logical_plan,
365 schema,
366 }) = describe_result
367 {
368 (Some(logical_plan), Some(schema))
369 } else {
370 (None, None)
371 };
372
373 Ok(PgSqlPlan {
374 plan: SqlPlan {
375 query: sql.clone(),
376 statement: Some(stmt),
377 plan,
378 schema,
379 },
380 copy_to_stdout_format,
381 })
382 }
383 }
384
385 fn get_parameter_types(&self, _stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
386 Err(PgWireError::ApiError(
389 "get_parameter_types is not expected to be called".into(),
390 ))
391 }
392
393 fn get_result_schema(
394 &self,
395 _stmt: &Self::Statement,
396 _column_format: Option<&Format>,
397 ) -> PgWireResult<Vec<FieldInfo>> {
398 Err(PgWireError::ApiError(
401 "get_result_schema is not expected to be called".into(),
402 ))
403 }
404}
405
406#[async_trait]
407impl ExtendedQueryHandler for PostgresServerHandlerInner {
408 type Statement = PgSqlPlan;
409 type QueryParser = DefaultQueryParser;
410
411 fn query_parser(&self) -> Arc<Self::QueryParser> {
412 self.query_parser.clone()
413 }
414
415 async fn do_query<C>(
416 &self,
417 client: &mut C,
418 portal: &Portal<Self::Statement>,
419 _max_rows: usize,
420 ) -> PgWireResult<Response>
421 where
422 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
423 C::Error: Debug,
424 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
425 {
426 let query_ctx = self.session.new_query_context();
427 let db = query_ctx.get_db_string();
428 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
429 .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
430 .start_timer();
431
432 let pg_sql_plan = &portal.statement.statement;
433 let sql_plan = &pg_sql_plan.plan;
434
435 if sql_plan.query.is_empty() {
436 return Ok(Response::EmptyQuery);
438 }
439
440 if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
441 send_warning_opt(client, query_ctx).await?;
442 return Ok(resps.remove(0));
444 }
445
446 let output = if let Some(plan) = &sql_plan.plan {
447 let values = parameters_to_scalar_values(plan, portal)?;
448 let plan = plan
449 .clone()
450 .replace_params_with_values(&ParamValues::List(
451 values.into_iter().map(Into::into).collect(),
452 ))
453 .context(DataFusionSnafu)
454 .map_err(convert_err)?;
455 self.query_handler
456 .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
457 .await
458 } else {
459 self.query_handler
465 .do_query(&sql_plan.query, query_ctx.clone())
466 .await
467 .remove(0)
468 };
469
470 send_warning_opt(client, query_ctx.clone()).await?;
471
472 if let Some(format) = &pg_sql_plan.copy_to_stdout_format {
473 output_to_copy_response(query_ctx, output, format)
474 } else {
475 output_to_query_response(query_ctx, output, &portal.result_column_format)
476 }
477 }
478
479 async fn do_describe_statement<C>(
480 &self,
481 _client: &mut C,
482 stmt: &StoredStatement<Self::Statement>,
483 ) -> PgWireResult<DescribeStatementResponse>
484 where
485 C: ClientInfo + Unpin + Send + Sync,
486 {
487 let sql_plan = &stmt.statement.plan;
488 let provided_param_types = &stmt.parameter_types;
490 let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
491 let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
492 .context(InferParameterTypesSnafu)
493 .map_err(convert_err)?
494 .into_iter()
495 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
496 .collect();
497
498 let types = param_types_to_pg_types(¶m_types).map_err(convert_err)?;
499
500 Some(types)
501 } else {
502 None
503 };
504
505 let param_count = if provided_param_types.is_empty() {
506 server_inferenced_types
507 .as_ref()
508 .map(|types| types.len())
509 .unwrap_or(0)
510 } else {
511 provided_param_types.len()
512 };
513
514 let param_types = (0..param_count)
515 .map(|i| {
516 let client_type = provided_param_types.get(i);
517 match client_type {
519 Some(Some(client_type)) => client_type.clone(),
520 _ => server_inferenced_types
521 .as_ref()
522 .and_then(|types| types.get(i).cloned())
523 .unwrap_or(Type::UNKNOWN),
524 }
525 })
526 .collect::<Vec<_>>();
527
528 if let Some(schema) = &sql_plan.schema {
529 schema_to_pg(schema, &Format::UnifiedText, None)
530 .map(|fields| DescribeStatementResponse::new(param_types, fields))
531 .map_err(convert_err)
532 } else {
533 if let Some(mut resp) =
534 fixtures::process(&sql_plan.query, self.session.new_query_context())
535 && let Response::Query(query_response) = resp.remove(0)
536 {
537 return Ok(DescribeStatementResponse::new(
538 param_types,
539 (*query_response.row_schema()).clone(),
540 ));
541 }
542
543 Ok(DescribeStatementResponse::new(param_types, vec![]))
544 }
545 }
546
547 async fn do_describe_portal<C>(
548 &self,
549 _client: &mut C,
550 portal: &Portal<Self::Statement>,
551 ) -> PgWireResult<DescribePortalResponse>
552 where
553 C: ClientInfo + Unpin + Send + Sync,
554 {
555 let sql_plan = &portal.statement.statement.plan;
556 let format = &portal.result_column_format;
557
558 match sql_plan.statement.as_ref() {
559 Some(Statement::Query(_)) => {
560 if let Some(schema) = &sql_plan.schema {
562 schema_to_pg(schema, format, None)
563 .map(DescribePortalResponse::new)
564 .map_err(convert_err)
565 } else {
566 Ok(DescribePortalResponse::new(vec![]))
568 }
569 }
570 Some(Statement::ShowCreateDatabase(_))
573 | Some(Statement::ShowCreateTable(_))
574 | Some(Statement::ShowCreateFlow(_))
575 | Some(Statement::ShowCreateView(_)) => Ok(DescribePortalResponse::new(vec![
576 FieldInfo::new(
577 "name".to_string(),
578 None,
579 None,
580 Type::TEXT,
581 format.format_for(0),
582 ),
583 FieldInfo::new(
584 "create_statement".to_string(),
585 None,
586 None,
587 Type::TEXT,
588 format.format_for(1),
589 ),
590 ])),
591 Some(Statement::ShowTables(_))
593 | Some(Statement::ShowFlows(_))
594 | Some(Statement::ShowViews(_)) => {
595 Ok(DescribePortalResponse::new(vec![FieldInfo::new(
596 "name".to_string(),
597 None,
598 None,
599 Type::TEXT,
600 format.format_for(0),
601 )]))
602 }
603 _ => {
606 if let Some(mut resp) =
608 fixtures::process(&sql_plan.query, self.session.new_query_context())
609 && let Response::Query(query_response) = resp.remove(0)
610 {
611 Ok(DescribePortalResponse::new(
612 (*query_response.row_schema()).clone(),
613 ))
614 } else {
615 Ok(DescribePortalResponse::new(vec![]))
617 }
618 }
619 }
620 }
621}
622
623impl ErrorHandler for PostgresServerHandlerInner {
624 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
625 where
626 C: ClientInfo,
627 {
628 match error {
629 PgWireError::IoError(e) => debug!("Postgres client disconnected: {}", e),
630 _ => info!("Postgres interface error: {}", error),
631 }
632 }
633}
634
635fn check_copy_to_stdout(statement: &SqlParserStatement) -> Option<String> {
636 if let SqlParserStatement::Copy {
637 target, options, ..
638 } = statement
639 && matches!(target, CopyTarget::Stdout)
640 {
641 for opt in options {
642 if let CopyOption::Format(format_ident) = opt {
643 return Some(format_ident.value.to_lowercase());
644 }
645 }
646 return Some("txt".to_string());
647 }
648
649 None
650}
651
#[cfg(test)]
mod tests {
    use datafusion_pg_catalog::sql::PostgresCompatibilityParser;

    use super::*;

    /// Parses `sql` with the compatibility parser and returns its first statement.
    fn parse_copy_statement(sql: &str) -> SqlParserStatement {
        PostgresCompatibilityParser::new()
            .parse(sql)
            .unwrap()
            .into_iter()
            .next()
            .unwrap()
    }

    /// Asserts that `check_copy_to_stdout` reports `expected` for `sql`.
    #[track_caller]
    fn assert_copy_format(sql: &str, expected: Option<&str>) {
        let statement = parse_copy_statement(sql);
        assert_eq!(
            check_copy_to_stdout(&statement),
            expected.map(|format| format.to_string())
        );
    }

    #[test]
    fn test_check_copy_out_with_csv_format() {
        assert_copy_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT CSV)", Some("csv"));
    }

    #[test]
    fn test_check_copy_out_with_txt_format() {
        assert_copy_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT TXT)", Some("txt"));
    }

    #[test]
    fn test_check_copy_out_with_binary_format() {
        assert_copy_format(
            "COPY (SELECT 1) TO STDOUT WITH (FORMAT BINARY)",
            Some("binary"),
        );
    }

    #[test]
    fn test_check_copy_out_without_format() {
        assert_copy_format("COPY (SELECT 1) TO STDOUT", Some("txt"));
    }

    #[test]
    fn test_check_copy_out_to_file() {
        assert_copy_format(
            "COPY (SELECT 1) TO '/path/to/file.csv' WITH (FORMAT CSV)",
            None,
        );
    }

    #[test]
    fn test_check_copy_out_case_insensitive() {
        assert_copy_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT csv)", Some("csv"));
        assert_copy_format(
            "COPY (SELECT 1) TO STDOUT WITH (FORMAT binary)",
            Some("binary"),
        );
    }

    #[test]
    fn test_check_copy_out_with_multiple_options() {
        assert_copy_format(
            "COPY (SELECT 1) TO STDOUT WITH (FORMAT csv, DELIMITER ',', HEADER)",
            Some("csv"),
        );
        assert_copy_format(
            "COPY (SELECT 1) TO STDOUT WITH (DELIMITER ',', HEADER, FORMAT binary)",
            Some("binary"),
        );
    }
}