1use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_query::{Output, OutputData};
20use common_recordbatch::RecordBatch;
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_telemetry::{debug, tracing};
23use datafusion_common::ParamValues;
24use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31 DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::Session;
39use session::context::QueryContextRef;
40use snafu::ResultExt;
41use sql::dialect::PostgreSqlDialect;
42use sql::parser::{ParseOptions, ParserContext};
43
44use crate::SqlPlan;
45use crate::error::{DataFusionSnafu, Result};
46use crate::postgres::types::*;
47use crate::postgres::utils::convert_err;
48use crate::postgres::{PostgresServerHandlerInner, fixtures};
49use crate::query_handler::sql::ServerSqlQueryHandlerRef;
50
51#[async_trait]
52impl SimpleQueryHandler for PostgresServerHandlerInner {
53 #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
54 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
55 where
56 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
57 C::Error: Debug,
58 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
59 {
60 let query_ctx = self.session.new_query_context();
61 let db = query_ctx.get_db_string();
62 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
63 .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
64 .start_timer();
65
66 if query.is_empty() {
67 return Ok(vec![Response::EmptyQuery]);
69 }
70
71 let query = if let Ok(statements) = self.query_parser.compatibility_parser.parse(query) {
72 statements
73 .iter()
74 .map(|s| s.to_string())
75 .collect::<Vec<_>>()
76 .join(";")
77 } else {
78 query.to_string()
79 };
80
81 if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
82 send_warning_opt(client, query_ctx).await?;
83 Ok(resps)
84 } else {
85 let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
86
87 let mut results = Vec::with_capacity(outputs.len());
88
89 for output in outputs {
90 let resp =
91 output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
92 results.push(resp);
93 }
94
95 send_warning_opt(client, query_ctx).await?;
96 Ok(results)
97 }
98 }
99}
100
101async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
102where
103 C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
104 C::Error: Debug,
105 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
106{
107 if let Some(warning) = query_context.warning() {
108 client
109 .feed(PgWireBackendMessage::NoticeResponse(
110 ErrorInfo::new(
111 PgErrorSeverity::Warning.to_string(),
112 PgErrorCode::Ec01000.code(),
113 warning.clone(),
114 )
115 .into(),
116 ))
117 .await?;
118 }
119
120 Ok(())
121}
122
123pub(crate) fn output_to_query_response(
124 query_ctx: QueryContextRef,
125 output: Result<Output>,
126 field_format: &Format,
127) -> PgWireResult<Response> {
128 match output {
129 Ok(o) => match o.data {
130 OutputData::AffectedRows(rows) => {
131 Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
132 }
133 OutputData::Stream(record_stream) => {
134 let schema = record_stream.schema();
135 recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
136 }
137 OutputData::RecordBatches(recordbatches) => {
138 let schema = recordbatches.schema();
139 recordbatches_to_query_response(
140 query_ctx,
141 recordbatches.as_stream(),
142 schema,
143 field_format,
144 )
145 }
146 },
147 Err(e) => Err(convert_err(e)),
148 }
149}
150
151fn recordbatches_to_query_response<S>(
152 query_ctx: QueryContextRef,
153 recordbatches_stream: S,
154 schema: SchemaRef,
155 field_format: &Format,
156) -> PgWireResult<Response>
157where
158 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
159{
160 let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?);
161 let pg_schema_ref = pg_schema.clone();
162 let data_row_stream = recordbatches_stream
163 .map(|record_batch_result| match record_batch_result {
164 Ok(rb) => stream::iter(
165 rb.rows().map(Ok).collect::<Vec<_>>(),
168 )
169 .boxed(),
170 Err(e) => stream::once(future::err(convert_err(e))).boxed(),
171 })
172 .flatten() .map(move |row| {
174 row.and_then(|row| {
175 let mut encoder = DataRowEncoder::new(pg_schema_ref.clone());
176 for (value, column) in row.iter().zip(schema.column_schemas()) {
177 encode_value(&query_ctx, value, &mut encoder, &column.data_type)?;
178 }
179 encoder.finish()
180 })
181 });
182
183 Ok(Response::Query(QueryResponse::new(
184 pg_schema,
185 data_row_stream,
186 )))
187}
188
189pub struct DefaultQueryParser {
190 query_handler: ServerSqlQueryHandlerRef,
191 session: Arc<Session>,
192 compatibility_parser: PostgresCompatibilityParser,
193}
194
195impl DefaultQueryParser {
196 pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
197 DefaultQueryParser {
198 query_handler,
199 session,
200 compatibility_parser: PostgresCompatibilityParser::new(),
201 }
202 }
203}
204
205#[async_trait]
206impl QueryParser for DefaultQueryParser {
207 type Statement = SqlPlan;
208
209 async fn parse_sql<C>(
210 &self,
211 _client: &C,
212 sql: &str,
213 _types: &[Type],
214 ) -> PgWireResult<Self::Statement> {
215 crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
216 let query_ctx = self.session.new_query_context();
217
218 if sql.is_empty() || fixtures::matches(sql) {
220 return Ok(SqlPlan {
221 query: sql.to_owned(),
222 statement: None,
223 plan: None,
224 schema: None,
225 });
226 }
227
228 let sql = if let Ok(mut statements) = self.compatibility_parser.parse(sql) {
229 statements.remove(0).to_string()
230 } else {
231 sql.to_string()
234 };
235
236 let mut stmts = ParserContext::create_with_dialect(
237 &sql,
238 &PostgreSqlDialect {},
239 ParseOptions::default(),
240 )
241 .map_err(convert_err)?;
242 if stmts.len() != 1 {
243 Err(PgWireError::UserError(Box::new(ErrorInfo::from(
244 PgErrorCode::Ec42P14,
245 ))))
246 } else {
247 let stmt = stmts.remove(0);
248
249 let describe_result = self
250 .query_handler
251 .do_describe(stmt.clone(), query_ctx)
252 .await
253 .map_err(convert_err)?;
254
255 let (plan, schema) = if let Some(DescribeResult {
256 logical_plan,
257 schema,
258 }) = describe_result
259 {
260 (Some(logical_plan), Some(schema))
261 } else {
262 (None, None)
263 };
264
265 Ok(SqlPlan {
266 query: sql.clone(),
267 statement: Some(stmt),
268 plan,
269 schema,
270 })
271 }
272 }
273}
274
275#[async_trait]
276impl ExtendedQueryHandler for PostgresServerHandlerInner {
277 type Statement = SqlPlan;
278 type QueryParser = DefaultQueryParser;
279
280 fn query_parser(&self) -> Arc<Self::QueryParser> {
281 self.query_parser.clone()
282 }
283
284 async fn do_query<C>(
285 &self,
286 client: &mut C,
287 portal: &Portal<Self::Statement>,
288 _max_rows: usize,
289 ) -> PgWireResult<Response>
290 where
291 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
292 C::Error: Debug,
293 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
294 {
295 let query_ctx = self.session.new_query_context();
296 let db = query_ctx.get_db_string();
297 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
298 .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
299 .start_timer();
300
301 let sql_plan = &portal.statement.statement;
302
303 if sql_plan.query.is_empty() {
304 return Ok(Response::EmptyQuery);
306 }
307
308 if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
309 send_warning_opt(client, query_ctx).await?;
310 return Ok(resps.remove(0));
312 }
313
314 let output = if let Some(plan) = &sql_plan.plan {
315 let plan = plan
316 .clone()
317 .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
318 plan, portal,
319 )?))
320 .context(DataFusionSnafu)
321 .map_err(convert_err)?;
322 self.query_handler
323 .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
324 .await
325 } else {
326 let mut sql = sql_plan.query.clone();
330 for i in 0..portal.parameter_len() {
331 sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?);
332 }
333
334 self.query_handler
335 .do_query(&sql, query_ctx.clone())
336 .await
337 .remove(0)
338 };
339
340 send_warning_opt(client, query_ctx.clone()).await?;
341 output_to_query_response(query_ctx, output, &portal.result_column_format)
342 }
343
344 async fn do_describe_statement<C>(
345 &self,
346 _client: &mut C,
347 stmt: &StoredStatement<Self::Statement>,
348 ) -> PgWireResult<DescribeStatementResponse>
349 where
350 C: ClientInfo + Unpin + Send + Sync,
351 {
352 let sql_plan = &stmt.statement;
353 let (param_types, sql_plan, format) = if let Some(plan) = &sql_plan.plan {
354 let param_types = plan
355 .get_parameter_types()
356 .context(DataFusionSnafu)
357 .map_err(convert_err)?
358 .into_iter()
359 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
360 .collect();
361
362 let types = param_types_to_pg_types(¶m_types).map_err(convert_err)?;
363
364 (types, sql_plan, &Format::UnifiedBinary)
365 } else {
366 let param_types = stmt.parameter_types.clone();
367 (param_types, sql_plan, &Format::UnifiedBinary)
368 };
369
370 if let Some(schema) = &sql_plan.schema {
371 schema_to_pg(schema, format)
372 .map(|fields| DescribeStatementResponse::new(param_types, fields))
373 .map_err(convert_err)
374 } else {
375 if let Some(mut resp) =
376 fixtures::process(&sql_plan.query, self.session.new_query_context())
377 && let Response::Query(query_response) = resp.remove(0)
378 {
379 return Ok(DescribeStatementResponse::new(
380 param_types,
381 (*query_response.row_schema()).clone(),
382 ));
383 }
384
385 Ok(DescribeStatementResponse::new(param_types, vec![]))
386 }
387 }
388
389 async fn do_describe_portal<C>(
390 &self,
391 _client: &mut C,
392 portal: &Portal<Self::Statement>,
393 ) -> PgWireResult<DescribePortalResponse>
394 where
395 C: ClientInfo + Unpin + Send + Sync,
396 {
397 let sql_plan = &portal.statement.statement;
398 let format = &portal.result_column_format;
399
400 if let Some(schema) = &sql_plan.schema {
401 schema_to_pg(schema, format)
402 .map(DescribePortalResponse::new)
403 .map_err(convert_err)
404 } else {
405 if let Some(mut resp) =
406 fixtures::process(&sql_plan.query, self.session.new_query_context())
407 && let Response::Query(query_response) = resp.remove(0)
408 {
409 return Ok(DescribePortalResponse::new(
410 (*query_response.row_schema()).clone(),
411 ));
412 }
413
414 Ok(DescribePortalResponse::new(vec![]))
415 }
416 }
417}
418
419impl ErrorHandler for PostgresServerHandlerInner {
420 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
421 where
422 C: ClientInfo,
423 {
424 debug!("Postgres interface error {}", error)
425 }
426}