1use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_query::{Output, OutputData};
20use common_recordbatch::RecordBatch;
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_telemetry::{debug, tracing};
23use datafusion_common::ParamValues;
24use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31 DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::Session;
39use session::context::QueryContextRef;
40use snafu::ResultExt;
41use sql::dialect::PostgreSqlDialect;
42use sql::parser::{ParseOptions, ParserContext};
43use sql::statements::statement::Statement;
44
45use crate::SqlPlan;
46use crate::error::{DataFusionSnafu, Result};
47use crate::postgres::types::*;
48use crate::postgres::utils::convert_err;
49use crate::postgres::{PostgresServerHandlerInner, fixtures};
50use crate::query_handler::sql::ServerSqlQueryHandlerRef;
51
52#[async_trait]
53impl SimpleQueryHandler for PostgresServerHandlerInner {
54 #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
55 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
56 where
57 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
58 C::Error: Debug,
59 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
60 {
61 let query_ctx = self.session.new_query_context();
62 let db = query_ctx.get_db_string();
63 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
64 .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
65 .start_timer();
66
67 if query.is_empty() {
68 return Ok(vec![Response::EmptyQuery]);
70 }
71
72 let query = if let Ok(statements) = self.query_parser.compatibility_parser.parse(query) {
73 statements
74 .iter()
75 .map(|s| s.to_string())
76 .collect::<Vec<_>>()
77 .join(";")
78 } else {
79 query.to_string()
80 };
81
82 if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
83 send_warning_opt(client, query_ctx).await?;
84 Ok(resps)
85 } else {
86 let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
87
88 let mut results = Vec::with_capacity(outputs.len());
89
90 for output in outputs {
91 let resp =
92 output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
93 results.push(resp);
94 }
95
96 send_warning_opt(client, query_ctx).await?;
97 Ok(results)
98 }
99 }
100}
101
102async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
103where
104 C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
105 C::Error: Debug,
106 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
107{
108 if let Some(warning) = query_context.warning() {
109 client
110 .feed(PgWireBackendMessage::NoticeResponse(
111 ErrorInfo::new(
112 PgErrorSeverity::Warning.to_string(),
113 PgErrorCode::Ec01000.code(),
114 warning.clone(),
115 )
116 .into(),
117 ))
118 .await?;
119 }
120
121 Ok(())
122}
123
124pub(crate) fn output_to_query_response(
125 query_ctx: QueryContextRef,
126 output: Result<Output>,
127 field_format: &Format,
128) -> PgWireResult<Response> {
129 match output {
130 Ok(o) => match o.data {
131 OutputData::AffectedRows(rows) => {
132 Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
133 }
134 OutputData::Stream(record_stream) => {
135 let schema = record_stream.schema();
136 recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
137 }
138 OutputData::RecordBatches(recordbatches) => {
139 let schema = recordbatches.schema();
140 recordbatches_to_query_response(
141 query_ctx,
142 recordbatches.as_stream(),
143 schema,
144 field_format,
145 )
146 }
147 },
148 Err(e) => Err(convert_err(e)),
149 }
150}
151
152fn recordbatches_to_query_response<S>(
153 query_ctx: QueryContextRef,
154 recordbatches_stream: S,
155 schema: SchemaRef,
156 field_format: &Format,
157) -> PgWireResult<Response>
158where
159 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
160{
161 let format_options = format_options_from_query_ctx(&query_ctx);
162 let pg_schema = Arc::new(
163 schema_to_pg(schema.as_ref(), field_format, Some(format_options)).map_err(convert_err)?,
164 );
165 let pg_schema_ref = pg_schema.clone();
166 let data_row_stream = recordbatches_stream
167 .map(move |result| match result {
168 Ok(record_batch) => stream::iter(RecordBatchRowIterator::new(
169 query_ctx.clone(),
170 pg_schema_ref.clone(),
171 record_batch,
172 ))
173 .boxed(),
174 Err(e) => stream::once(future::err(convert_err(e))).boxed(),
175 })
176 .flatten();
177
178 Ok(Response::Query(QueryResponse::new(
179 pg_schema,
180 data_row_stream,
181 )))
182}
183
/// Parser used by the extended query protocol (`Parse` messages) to turn SQL
/// text into [`SqlPlan`]s.
///
/// Its compatibility parser is also reused by the simple-query path to
/// normalize SQL text before execution.
pub struct DefaultQueryParser {
    /// Handler used to describe (plan) parsed statements.
    query_handler: ServerSqlQueryHandlerRef,
    /// Session from which per-query contexts are created.
    session: Arc<Session>,
    /// Postgres compatibility parser whose re-serialized statements replace the
    /// incoming SQL text before GreptimeDB's own parser sees it.
    compatibility_parser: PostgresCompatibilityParser,
}
189
190impl DefaultQueryParser {
191 pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
192 DefaultQueryParser {
193 query_handler,
194 session,
195 compatibility_parser: PostgresCompatibilityParser::new(),
196 }
197 }
198}
199
200#[async_trait]
201impl QueryParser for DefaultQueryParser {
202 type Statement = SqlPlan;
203
204 async fn parse_sql<C>(
205 &self,
206 _client: &C,
207 sql: &str,
208 _types: &[Option<Type>],
209 ) -> PgWireResult<Self::Statement> {
210 crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
211 let query_ctx = self.session.new_query_context();
212
213 if sql.is_empty() || fixtures::matches(sql) {
215 return Ok(SqlPlan {
216 query: sql.to_owned(),
217 statement: None,
218 plan: None,
219 schema: None,
220 });
221 }
222
223 let sql = if let Ok(mut statements) = self.compatibility_parser.parse(sql) {
224 statements.remove(0).to_string()
225 } else {
226 sql.to_string()
229 };
230
231 let mut stmts = ParserContext::create_with_dialect(
232 &sql,
233 &PostgreSqlDialect {},
234 ParseOptions::default(),
235 )
236 .map_err(convert_err)?;
237 if stmts.len() != 1 {
238 Err(PgWireError::UserError(Box::new(ErrorInfo::from(
239 PgErrorCode::Ec42P14,
240 ))))
241 } else {
242 let stmt = stmts.remove(0);
243
244 let describe_result = self
245 .query_handler
246 .do_describe(stmt.clone(), query_ctx)
247 .await
248 .map_err(convert_err)?;
249
250 let (plan, schema) = if let Some(DescribeResult {
251 logical_plan,
252 schema,
253 }) = describe_result
254 {
255 (Some(logical_plan), Some(schema))
256 } else {
257 (None, None)
258 };
259
260 Ok(SqlPlan {
261 query: sql.clone(),
262 statement: Some(stmt),
263 plan,
264 schema,
265 })
266 }
267 }
268
269 fn get_parameter_types(&self, _stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
270 Err(PgWireError::ApiError(
273 "get_parameter_types is not expected to be called".into(),
274 ))
275 }
276
277 fn get_result_schema(
278 &self,
279 _stmt: &Self::Statement,
280 _column_format: Option<&Format>,
281 ) -> PgWireResult<Vec<FieldInfo>> {
282 Err(PgWireError::ApiError(
285 "get_result_schema is not expected to be called".into(),
286 ))
287 }
288}
289
290#[async_trait]
291impl ExtendedQueryHandler for PostgresServerHandlerInner {
292 type Statement = SqlPlan;
293 type QueryParser = DefaultQueryParser;
294
295 fn query_parser(&self) -> Arc<Self::QueryParser> {
296 self.query_parser.clone()
297 }
298
299 async fn do_query<C>(
300 &self,
301 client: &mut C,
302 portal: &Portal<Self::Statement>,
303 _max_rows: usize,
304 ) -> PgWireResult<Response>
305 where
306 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
307 C::Error: Debug,
308 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
309 {
310 let query_ctx = self.session.new_query_context();
311 let db = query_ctx.get_db_string();
312 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
313 .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
314 .start_timer();
315
316 let sql_plan = &portal.statement.statement;
317
318 if sql_plan.query.is_empty() {
319 return Ok(Response::EmptyQuery);
321 }
322
323 if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
324 send_warning_opt(client, query_ctx).await?;
325 return Ok(resps.remove(0));
327 }
328
329 let output = if let Some(plan) = &sql_plan.plan {
330 let values = parameters_to_scalar_values(plan, portal)?;
331 let plan = plan
332 .clone()
333 .replace_params_with_values(&ParamValues::List(
334 values.into_iter().map(Into::into).collect(),
335 ))
336 .context(DataFusionSnafu)
337 .map_err(convert_err)?;
338 self.query_handler
339 .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
340 .await
341 } else {
342 let mut sql = sql_plan.query.clone();
346 for i in 0..portal.parameter_len() {
347 sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?);
348 }
349
350 self.query_handler
351 .do_query(&sql, query_ctx.clone())
352 .await
353 .remove(0)
354 };
355
356 send_warning_opt(client, query_ctx.clone()).await?;
357 output_to_query_response(query_ctx, output, &portal.result_column_format)
358 }
359
360 async fn do_describe_statement<C>(
361 &self,
362 _client: &mut C,
363 stmt: &StoredStatement<Self::Statement>,
364 ) -> PgWireResult<DescribeStatementResponse>
365 where
366 C: ClientInfo + Unpin + Send + Sync,
367 {
368 let sql_plan = &stmt.statement;
369 let provided_param_types = &stmt.parameter_types;
371 let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
372 let param_types = plan
373 .get_parameter_types()
374 .context(DataFusionSnafu)
375 .map_err(convert_err)?
376 .into_iter()
377 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
378 .collect();
379
380 let types = param_types_to_pg_types(¶m_types).map_err(convert_err)?;
381
382 Some(types)
383 } else {
384 None
385 };
386
387 let param_count = if provided_param_types.is_empty() {
388 server_inferenced_types
389 .as_ref()
390 .map(|types| types.len())
391 .unwrap_or(0)
392 } else {
393 provided_param_types.len()
394 };
395
396 let param_types = (0..param_count)
397 .map(|i| {
398 let client_type = provided_param_types.get(i);
399 match client_type {
401 Some(Some(client_type)) => client_type.clone(),
402 _ => server_inferenced_types
403 .as_ref()
404 .and_then(|types| types.get(i).cloned())
405 .unwrap_or(Type::UNKNOWN),
406 }
407 })
408 .collect::<Vec<_>>();
409
410 if let Some(schema) = &sql_plan.schema {
411 schema_to_pg(schema, &Format::UnifiedBinary, None)
412 .map(|fields| DescribeStatementResponse::new(param_types, fields))
413 .map_err(convert_err)
414 } else {
415 if let Some(mut resp) =
416 fixtures::process(&sql_plan.query, self.session.new_query_context())
417 && let Response::Query(query_response) = resp.remove(0)
418 {
419 return Ok(DescribeStatementResponse::new(
420 param_types,
421 (*query_response.row_schema()).clone(),
422 ));
423 }
424
425 Ok(DescribeStatementResponse::new(param_types, vec![]))
426 }
427 }
428
429 async fn do_describe_portal<C>(
430 &self,
431 _client: &mut C,
432 portal: &Portal<Self::Statement>,
433 ) -> PgWireResult<DescribePortalResponse>
434 where
435 C: ClientInfo + Unpin + Send + Sync,
436 {
437 let sql_plan = &portal.statement.statement;
438 let format = &portal.result_column_format;
439
440 match sql_plan.statement.as_ref() {
441 Some(Statement::Query(_)) => {
442 if let Some(schema) = &sql_plan.schema {
444 schema_to_pg(schema, format, None)
445 .map(DescribePortalResponse::new)
446 .map_err(convert_err)
447 } else {
448 Ok(DescribePortalResponse::new(vec![]))
450 }
451 }
452 Some(Statement::ShowCreateDatabase(_))
455 | Some(Statement::ShowCreateTable(_))
456 | Some(Statement::ShowCreateFlow(_))
457 | Some(Statement::ShowCreateView(_)) => Ok(DescribePortalResponse::new(vec![
458 FieldInfo::new(
459 "name".to_string(),
460 None,
461 None,
462 Type::TEXT,
463 format.format_for(0),
464 ),
465 FieldInfo::new(
466 "create_statement".to_string(),
467 None,
468 None,
469 Type::TEXT,
470 format.format_for(1),
471 ),
472 ])),
473 Some(Statement::ShowTables(_))
475 | Some(Statement::ShowFlows(_))
476 | Some(Statement::ShowViews(_)) => {
477 Ok(DescribePortalResponse::new(vec![FieldInfo::new(
478 "name".to_string(),
479 None,
480 None,
481 Type::TEXT,
482 format.format_for(0),
483 )]))
484 }
485 _ => {
488 if let Some(mut resp) =
490 fixtures::process(&sql_plan.query, self.session.new_query_context())
491 && let Response::Query(query_response) = resp.remove(0)
492 {
493 Ok(DescribePortalResponse::new(
494 (*query_response.row_schema()).clone(),
495 ))
496 } else {
497 Ok(DescribePortalResponse::new(vec![]))
499 }
500 }
501 }
502 }
503}
504
505impl ErrorHandler for PostgresServerHandlerInner {
506 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
507 where
508 C: ClientInfo,
509 {
510 debug!("Postgres interface error {}", error)
511 }
512}