1use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_query::{Output, OutputData};
20use common_recordbatch::RecordBatch;
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_telemetry::{debug, tracing};
23use datafusion_common::ParamValues;
24use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31 DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::Session;
39use session::context::QueryContextRef;
40use snafu::ResultExt;
41use sql::dialect::PostgreSqlDialect;
42use sql::parser::{ParseOptions, ParserContext};
43use sql::statements::statement::Statement;
44
45use crate::SqlPlan;
46use crate::error::{DataFusionSnafu, Result};
47use crate::postgres::types::*;
48use crate::postgres::utils::convert_err;
49use crate::postgres::{PostgresServerHandlerInner, fixtures};
50use crate::query_handler::sql::ServerSqlQueryHandlerRef;
51
52#[async_trait]
53impl SimpleQueryHandler for PostgresServerHandlerInner {
54 #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
55 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
56 where
57 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
58 C::Error: Debug,
59 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
60 {
61 let query_ctx = self.session.new_query_context();
62 let db = query_ctx.get_db_string();
63 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
64 .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
65 .start_timer();
66
67 if query.is_empty() {
68 return Ok(vec![Response::EmptyQuery]);
70 }
71
72 let query = if let Ok(statements) = self.query_parser.compatibility_parser.parse(query) {
73 statements
74 .iter()
75 .map(|s| s.to_string())
76 .collect::<Vec<_>>()
77 .join(";")
78 } else {
79 query.to_string()
80 };
81
82 if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
83 send_warning_opt(client, query_ctx).await?;
84 Ok(resps)
85 } else {
86 let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
87
88 let mut results = Vec::with_capacity(outputs.len());
89
90 for output in outputs {
91 let resp =
92 output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
93 results.push(resp);
94 }
95
96 send_warning_opt(client, query_ctx).await?;
97 Ok(results)
98 }
99 }
100}
101
102async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
103where
104 C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
105 C::Error: Debug,
106 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
107{
108 if let Some(warning) = query_context.warning() {
109 client
110 .feed(PgWireBackendMessage::NoticeResponse(
111 ErrorInfo::new(
112 PgErrorSeverity::Warning.to_string(),
113 PgErrorCode::Ec01000.code(),
114 warning.clone(),
115 )
116 .into(),
117 ))
118 .await?;
119 }
120
121 Ok(())
122}
123
124pub(crate) fn output_to_query_response(
125 query_ctx: QueryContextRef,
126 output: Result<Output>,
127 field_format: &Format,
128) -> PgWireResult<Response> {
129 match output {
130 Ok(o) => match o.data {
131 OutputData::AffectedRows(rows) => {
132 Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
133 }
134 OutputData::Stream(record_stream) => {
135 let schema = record_stream.schema();
136 recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
137 }
138 OutputData::RecordBatches(recordbatches) => {
139 let schema = recordbatches.schema();
140 recordbatches_to_query_response(
141 query_ctx,
142 recordbatches.as_stream(),
143 schema,
144 field_format,
145 )
146 }
147 },
148 Err(e) => Err(convert_err(e)),
149 }
150}
151
152fn recordbatches_to_query_response<S>(
153 query_ctx: QueryContextRef,
154 recordbatches_stream: S,
155 schema: SchemaRef,
156 field_format: &Format,
157) -> PgWireResult<Response>
158where
159 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
160{
161 let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?);
162 let pg_schema_ref = pg_schema.clone();
163 let data_row_stream = recordbatches_stream
164 .map(move |result| match result {
165 Ok(record_batch) => stream::iter(RecordBatchRowIterator::new(
166 query_ctx.clone(),
167 pg_schema_ref.clone(),
168 record_batch,
169 ))
170 .boxed(),
171 Err(e) => stream::once(future::err(convert_err(e))).boxed(),
172 })
173 .flatten();
174
175 Ok(Response::Query(QueryResponse::new(
176 pg_schema,
177 data_row_stream,
178 )))
179}
180
/// [`QueryParser`] used by the extended query protocol to turn SQL text into a [`SqlPlan`].
pub struct DefaultQueryParser {
    /// Handler used to describe (plan) each parsed statement.
    query_handler: ServerSqlQueryHandlerRef,
    /// Session from which a fresh query context is created on every `parse_sql` call.
    session: Arc<Session>,
    /// Pre-parses incoming SQL; when it succeeds, its re-serialized statement(s) replace the
    /// original text. Also used directly by the simple-query handler.
    compatibility_parser: PostgresCompatibilityParser,
}
186
187impl DefaultQueryParser {
188 pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
189 DefaultQueryParser {
190 query_handler,
191 session,
192 compatibility_parser: PostgresCompatibilityParser::new(),
193 }
194 }
195}
196
197#[async_trait]
198impl QueryParser for DefaultQueryParser {
199 type Statement = SqlPlan;
200
201 async fn parse_sql<C>(
202 &self,
203 _client: &C,
204 sql: &str,
205 _types: &[Option<Type>],
206 ) -> PgWireResult<Self::Statement> {
207 crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
208 let query_ctx = self.session.new_query_context();
209
210 if sql.is_empty() || fixtures::matches(sql) {
212 return Ok(SqlPlan {
213 query: sql.to_owned(),
214 statement: None,
215 plan: None,
216 schema: None,
217 });
218 }
219
220 let sql = if let Ok(mut statements) = self.compatibility_parser.parse(sql) {
221 statements.remove(0).to_string()
222 } else {
223 sql.to_string()
226 };
227
228 let mut stmts = ParserContext::create_with_dialect(
229 &sql,
230 &PostgreSqlDialect {},
231 ParseOptions::default(),
232 )
233 .map_err(convert_err)?;
234 if stmts.len() != 1 {
235 Err(PgWireError::UserError(Box::new(ErrorInfo::from(
236 PgErrorCode::Ec42P14,
237 ))))
238 } else {
239 let stmt = stmts.remove(0);
240
241 let describe_result = self
242 .query_handler
243 .do_describe(stmt.clone(), query_ctx)
244 .await
245 .map_err(convert_err)?;
246
247 let (plan, schema) = if let Some(DescribeResult {
248 logical_plan,
249 schema,
250 }) = describe_result
251 {
252 (Some(logical_plan), Some(schema))
253 } else {
254 (None, None)
255 };
256
257 Ok(SqlPlan {
258 query: sql.clone(),
259 statement: Some(stmt),
260 plan,
261 schema,
262 })
263 }
264 }
265
266 fn get_parameter_types(&self, _stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
267 Err(PgWireError::ApiError(
270 "get_parameter_types is not expected to be called".into(),
271 ))
272 }
273
274 fn get_result_schema(
275 &self,
276 _stmt: &Self::Statement,
277 _column_format: Option<&Format>,
278 ) -> PgWireResult<Vec<FieldInfo>> {
279 Err(PgWireError::ApiError(
282 "get_result_schema is not expected to be called".into(),
283 ))
284 }
285}
286
287#[async_trait]
288impl ExtendedQueryHandler for PostgresServerHandlerInner {
289 type Statement = SqlPlan;
290 type QueryParser = DefaultQueryParser;
291
292 fn query_parser(&self) -> Arc<Self::QueryParser> {
293 self.query_parser.clone()
294 }
295
296 async fn do_query<C>(
297 &self,
298 client: &mut C,
299 portal: &Portal<Self::Statement>,
300 _max_rows: usize,
301 ) -> PgWireResult<Response>
302 where
303 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
304 C::Error: Debug,
305 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
306 {
307 let query_ctx = self.session.new_query_context();
308 let db = query_ctx.get_db_string();
309 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
310 .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
311 .start_timer();
312
313 let sql_plan = &portal.statement.statement;
314
315 if sql_plan.query.is_empty() {
316 return Ok(Response::EmptyQuery);
318 }
319
320 if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
321 send_warning_opt(client, query_ctx).await?;
322 return Ok(resps.remove(0));
324 }
325
326 let output = if let Some(plan) = &sql_plan.plan {
327 let plan = plan
328 .clone()
329 .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
330 plan, portal,
331 )?))
332 .context(DataFusionSnafu)
333 .map_err(convert_err)?;
334 self.query_handler
335 .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
336 .await
337 } else {
338 let mut sql = sql_plan.query.clone();
342 for i in 0..portal.parameter_len() {
343 sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?);
344 }
345
346 self.query_handler
347 .do_query(&sql, query_ctx.clone())
348 .await
349 .remove(0)
350 };
351
352 send_warning_opt(client, query_ctx.clone()).await?;
353 output_to_query_response(query_ctx, output, &portal.result_column_format)
354 }
355
356 async fn do_describe_statement<C>(
357 &self,
358 _client: &mut C,
359 stmt: &StoredStatement<Self::Statement>,
360 ) -> PgWireResult<DescribeStatementResponse>
361 where
362 C: ClientInfo + Unpin + Send + Sync,
363 {
364 let sql_plan = &stmt.statement;
365 let provided_param_types = &stmt.parameter_types;
367 let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
368 let param_types = plan
369 .get_parameter_types()
370 .context(DataFusionSnafu)
371 .map_err(convert_err)?
372 .into_iter()
373 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
374 .collect();
375
376 let types = param_types_to_pg_types(¶m_types).map_err(convert_err)?;
377
378 Some(types)
379 } else {
380 None
381 };
382
383 let param_count = if provided_param_types.is_empty() {
384 server_inferenced_types
385 .as_ref()
386 .map(|types| types.len())
387 .unwrap_or(0)
388 } else {
389 provided_param_types.len()
390 };
391
392 let param_types = (0..param_count)
393 .map(|i| {
394 let client_type = provided_param_types.get(i);
395 match client_type {
397 Some(Some(client_type)) => client_type.clone(),
398 _ => server_inferenced_types
399 .as_ref()
400 .and_then(|types| types.get(i).cloned())
401 .unwrap_or(Type::UNKNOWN),
402 }
403 })
404 .collect::<Vec<_>>();
405
406 if let Some(schema) = &sql_plan.schema {
407 schema_to_pg(schema, &Format::UnifiedBinary)
408 .map(|fields| DescribeStatementResponse::new(param_types, fields))
409 .map_err(convert_err)
410 } else {
411 if let Some(mut resp) =
412 fixtures::process(&sql_plan.query, self.session.new_query_context())
413 && let Response::Query(query_response) = resp.remove(0)
414 {
415 return Ok(DescribeStatementResponse::new(
416 param_types,
417 (*query_response.row_schema()).clone(),
418 ));
419 }
420
421 Ok(DescribeStatementResponse::new(param_types, vec![]))
422 }
423 }
424
425 async fn do_describe_portal<C>(
426 &self,
427 _client: &mut C,
428 portal: &Portal<Self::Statement>,
429 ) -> PgWireResult<DescribePortalResponse>
430 where
431 C: ClientInfo + Unpin + Send + Sync,
432 {
433 let sql_plan = &portal.statement.statement;
434 let format = &portal.result_column_format;
435
436 match sql_plan.statement.as_ref() {
437 Some(Statement::Query(_)) => {
438 if let Some(schema) = &sql_plan.schema {
440 schema_to_pg(schema, format)
441 .map(DescribePortalResponse::new)
442 .map_err(convert_err)
443 } else {
444 Ok(DescribePortalResponse::new(vec![]))
446 }
447 }
448 Some(Statement::ShowCreateDatabase(_))
451 | Some(Statement::ShowCreateTable(_))
452 | Some(Statement::ShowCreateFlow(_))
453 | Some(Statement::ShowCreateView(_)) => Ok(DescribePortalResponse::new(vec![
454 FieldInfo::new(
455 "name".to_string(),
456 None,
457 None,
458 Type::TEXT,
459 format.format_for(0),
460 ),
461 FieldInfo::new(
462 "create_statement".to_string(),
463 None,
464 None,
465 Type::TEXT,
466 format.format_for(1),
467 ),
468 ])),
469 Some(Statement::ShowTables(_))
471 | Some(Statement::ShowFlows(_))
472 | Some(Statement::ShowViews(_)) => {
473 Ok(DescribePortalResponse::new(vec![FieldInfo::new(
474 "name".to_string(),
475 None,
476 None,
477 Type::TEXT,
478 format.format_for(0),
479 )]))
480 }
481 _ => {
484 if let Some(mut resp) =
486 fixtures::process(&sql_plan.query, self.session.new_query_context())
487 && let Response::Query(query_response) = resp.remove(0)
488 {
489 Ok(DescribePortalResponse::new(
490 (*query_response.row_schema()).clone(),
491 ))
492 } else {
493 Ok(DescribePortalResponse::new(vec![]))
495 }
496 }
497 }
498 }
499}
500
501impl ErrorHandler for PostgresServerHandlerInner {
502 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
503 where
504 C: ClientInfo,
505 {
506 debug!("Postgres interface error {}", error)
507 }
508}