1use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_query::{Output, OutputData};
20use common_recordbatch::RecordBatch;
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_telemetry::{debug, tracing};
23use datafusion_common::ParamValues;
24use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31 DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::Session;
39use session::context::QueryContextRef;
40use snafu::ResultExt;
41use sql::dialect::PostgreSqlDialect;
42use sql::parser::{ParseOptions, ParserContext};
43
44use crate::SqlPlan;
45use crate::error::{DataFusionSnafu, Result};
46use crate::postgres::types::*;
47use crate::postgres::utils::convert_err;
48use crate::postgres::{PostgresServerHandlerInner, fixtures};
49use crate::query_handler::sql::ServerSqlQueryHandlerRef;
50
51#[async_trait]
52impl SimpleQueryHandler for PostgresServerHandlerInner {
53 #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
54 async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
55 where
56 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
57 C::Error: Debug,
58 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
59 {
60 let query_ctx = self.session.new_query_context();
61 let db = query_ctx.get_db_string();
62 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
63 .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
64 .start_timer();
65
66 if query.is_empty() {
67 return Ok(vec![Response::EmptyQuery]);
69 }
70
71 let query = if let Ok(statements) = self.query_parser.compatibility_parser.parse(query) {
72 statements
73 .iter()
74 .map(|s| s.to_string())
75 .collect::<Vec<_>>()
76 .join(";")
77 } else {
78 query.to_string()
79 };
80
81 if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
82 send_warning_opt(client, query_ctx).await?;
83 Ok(resps)
84 } else {
85 let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
86
87 let mut results = Vec::with_capacity(outputs.len());
88
89 for output in outputs {
90 let resp =
91 output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
92 results.push(resp);
93 }
94
95 send_warning_opt(client, query_ctx).await?;
96 Ok(results)
97 }
98 }
99}
100
101async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
102where
103 C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
104 C::Error: Debug,
105 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
106{
107 if let Some(warning) = query_context.warning() {
108 client
109 .feed(PgWireBackendMessage::NoticeResponse(
110 ErrorInfo::new(
111 PgErrorSeverity::Warning.to_string(),
112 PgErrorCode::Ec01000.code(),
113 warning.clone(),
114 )
115 .into(),
116 ))
117 .await?;
118 }
119
120 Ok(())
121}
122
123pub(crate) fn output_to_query_response(
124 query_ctx: QueryContextRef,
125 output: Result<Output>,
126 field_format: &Format,
127) -> PgWireResult<Response> {
128 match output {
129 Ok(o) => match o.data {
130 OutputData::AffectedRows(rows) => {
131 Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
132 }
133 OutputData::Stream(record_stream) => {
134 let schema = record_stream.schema();
135 recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
136 }
137 OutputData::RecordBatches(recordbatches) => {
138 let schema = recordbatches.schema();
139 recordbatches_to_query_response(
140 query_ctx,
141 recordbatches.as_stream(),
142 schema,
143 field_format,
144 )
145 }
146 },
147 Err(e) => Err(convert_err(e)),
148 }
149}
150
151fn recordbatches_to_query_response<S>(
152 query_ctx: QueryContextRef,
153 recordbatches_stream: S,
154 schema: SchemaRef,
155 field_format: &Format,
156) -> PgWireResult<Response>
157where
158 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
159{
160 let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?);
161 let pg_schema_ref = pg_schema.clone();
162 let data_row_stream = recordbatches_stream
163 .map(move |result| match result {
164 Ok(record_batch) => stream::iter(RecordBatchRowIterator::new(
165 query_ctx.clone(),
166 pg_schema_ref.clone(),
167 record_batch,
168 ))
169 .boxed(),
170 Err(e) => stream::once(future::err(convert_err(e))).boxed(),
171 })
172 .flatten();
173
174 Ok(Response::Query(QueryResponse::new(
175 pg_schema,
176 data_row_stream,
177 )))
178}
179
180pub struct DefaultQueryParser {
181 query_handler: ServerSqlQueryHandlerRef,
182 session: Arc<Session>,
183 compatibility_parser: PostgresCompatibilityParser,
184}
185
186impl DefaultQueryParser {
187 pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
188 DefaultQueryParser {
189 query_handler,
190 session,
191 compatibility_parser: PostgresCompatibilityParser::new(),
192 }
193 }
194}
195
196#[async_trait]
197impl QueryParser for DefaultQueryParser {
198 type Statement = SqlPlan;
199
200 async fn parse_sql<C>(
201 &self,
202 _client: &C,
203 sql: &str,
204 _types: &[Type],
205 ) -> PgWireResult<Self::Statement> {
206 crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
207 let query_ctx = self.session.new_query_context();
208
209 if sql.is_empty() || fixtures::matches(sql) {
211 return Ok(SqlPlan {
212 query: sql.to_owned(),
213 statement: None,
214 plan: None,
215 schema: None,
216 });
217 }
218
219 let sql = if let Ok(mut statements) = self.compatibility_parser.parse(sql) {
220 statements.remove(0).to_string()
221 } else {
222 sql.to_string()
225 };
226
227 let mut stmts = ParserContext::create_with_dialect(
228 &sql,
229 &PostgreSqlDialect {},
230 ParseOptions::default(),
231 )
232 .map_err(convert_err)?;
233 if stmts.len() != 1 {
234 Err(PgWireError::UserError(Box::new(ErrorInfo::from(
235 PgErrorCode::Ec42P14,
236 ))))
237 } else {
238 let stmt = stmts.remove(0);
239
240 let describe_result = self
241 .query_handler
242 .do_describe(stmt.clone(), query_ctx)
243 .await
244 .map_err(convert_err)?;
245
246 let (plan, schema) = if let Some(DescribeResult {
247 logical_plan,
248 schema,
249 }) = describe_result
250 {
251 (Some(logical_plan), Some(schema))
252 } else {
253 (None, None)
254 };
255
256 Ok(SqlPlan {
257 query: sql.clone(),
258 statement: Some(stmt),
259 plan,
260 schema,
261 })
262 }
263 }
264}
265
266#[async_trait]
267impl ExtendedQueryHandler for PostgresServerHandlerInner {
268 type Statement = SqlPlan;
269 type QueryParser = DefaultQueryParser;
270
271 fn query_parser(&self) -> Arc<Self::QueryParser> {
272 self.query_parser.clone()
273 }
274
275 async fn do_query<C>(
276 &self,
277 client: &mut C,
278 portal: &Portal<Self::Statement>,
279 _max_rows: usize,
280 ) -> PgWireResult<Response>
281 where
282 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
283 C::Error: Debug,
284 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
285 {
286 let query_ctx = self.session.new_query_context();
287 let db = query_ctx.get_db_string();
288 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
289 .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
290 .start_timer();
291
292 let sql_plan = &portal.statement.statement;
293
294 if sql_plan.query.is_empty() {
295 return Ok(Response::EmptyQuery);
297 }
298
299 if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
300 send_warning_opt(client, query_ctx).await?;
301 return Ok(resps.remove(0));
303 }
304
305 let output = if let Some(plan) = &sql_plan.plan {
306 let plan = plan
307 .clone()
308 .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
309 plan, portal,
310 )?))
311 .context(DataFusionSnafu)
312 .map_err(convert_err)?;
313 self.query_handler
314 .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
315 .await
316 } else {
317 let mut sql = sql_plan.query.clone();
321 for i in 0..portal.parameter_len() {
322 sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?);
323 }
324
325 self.query_handler
326 .do_query(&sql, query_ctx.clone())
327 .await
328 .remove(0)
329 };
330
331 send_warning_opt(client, query_ctx.clone()).await?;
332 output_to_query_response(query_ctx, output, &portal.result_column_format)
333 }
334
335 async fn do_describe_statement<C>(
336 &self,
337 _client: &mut C,
338 stmt: &StoredStatement<Self::Statement>,
339 ) -> PgWireResult<DescribeStatementResponse>
340 where
341 C: ClientInfo + Unpin + Send + Sync,
342 {
343 let sql_plan = &stmt.statement;
344 let (param_types, sql_plan, format) = if let Some(plan) = &sql_plan.plan {
345 let param_types = plan
346 .get_parameter_types()
347 .context(DataFusionSnafu)
348 .map_err(convert_err)?
349 .into_iter()
350 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
351 .collect();
352
353 let types = param_types_to_pg_types(¶m_types).map_err(convert_err)?;
354
355 (types, sql_plan, &Format::UnifiedBinary)
356 } else {
357 let param_types = stmt.parameter_types.clone();
358 (param_types, sql_plan, &Format::UnifiedBinary)
359 };
360
361 if let Some(schema) = &sql_plan.schema {
362 schema_to_pg(schema, format)
363 .map(|fields| DescribeStatementResponse::new(param_types, fields))
364 .map_err(convert_err)
365 } else {
366 if let Some(mut resp) =
367 fixtures::process(&sql_plan.query, self.session.new_query_context())
368 && let Response::Query(query_response) = resp.remove(0)
369 {
370 return Ok(DescribeStatementResponse::new(
371 param_types,
372 (*query_response.row_schema()).clone(),
373 ));
374 }
375
376 Ok(DescribeStatementResponse::new(param_types, vec![]))
377 }
378 }
379
380 async fn do_describe_portal<C>(
381 &self,
382 _client: &mut C,
383 portal: &Portal<Self::Statement>,
384 ) -> PgWireResult<DescribePortalResponse>
385 where
386 C: ClientInfo + Unpin + Send + Sync,
387 {
388 let sql_plan = &portal.statement.statement;
389 let format = &portal.result_column_format;
390
391 if let Some(schema) = &sql_plan.schema {
392 schema_to_pg(schema, format)
393 .map(DescribePortalResponse::new)
394 .map_err(convert_err)
395 } else {
396 if let Some(mut resp) =
397 fixtures::process(&sql_plan.query, self.session.new_query_context())
398 && let Response::Query(query_response) = resp.remove(0)
399 {
400 return Ok(DescribePortalResponse::new(
401 (*query_response.row_schema()).clone(),
402 ));
403 }
404
405 Ok(DescribePortalResponse::new(vec![]))
406 }
407 }
408}
409
410impl ErrorHandler for PostgresServerHandlerInner {
411 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
412 where
413 C: ClientInfo,
414 {
415 debug!("Postgres interface error {}", error)
416 }
417}