1use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_query::{Output, OutputData};
20use common_recordbatch::error::Result as RecordBatchResult;
21use common_recordbatch::RecordBatch;
22use common_telemetry::{debug, tracing};
23use datafusion_common::ParamValues;
24use datatypes::prelude::ConcreteDataType;
25use datatypes::schema::SchemaRef;
26use futures::{future, stream, Sink, SinkExt, Stream, StreamExt};
27use pgwire::api::portal::{Format, Portal};
28use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
29use pgwire::api::results::{
30 DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, Tag,
31};
32use pgwire::api::stmt::{QueryParser, StoredStatement};
33use pgwire::api::{ClientInfo, ErrorHandler, Type};
34use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
35use pgwire::messages::PgWireBackendMessage;
36use query::query_engine::DescribeResult;
37use session::context::QueryContextRef;
38use session::Session;
39use snafu::ResultExt;
40use sql::dialect::PostgreSqlDialect;
41use sql::parser::{ParseOptions, ParserContext};
42
43use crate::error::{DataFusionSnafu, Result};
44use crate::postgres::types::*;
45use crate::postgres::utils::convert_err;
46use crate::postgres::{fixtures, PostgresServerHandlerInner};
47use crate::query_handler::sql::ServerSqlQueryHandlerRef;
48use crate::SqlPlan;
49
50#[async_trait]
51impl SimpleQueryHandler for PostgresServerHandlerInner {
52 #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
53 async fn do_query<'a, C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
54 where
55 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
56 C::Error: Debug,
57 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
58 {
59 let query_ctx = self.session.new_query_context();
60 let db = query_ctx.get_db_string();
61 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
62 .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
63 .start_timer();
64
65 if query.is_empty() {
66 return Ok(vec![Response::EmptyQuery]);
68 }
69
70 let query = fixtures::rewrite_sql(query);
71 let query = query.as_ref();
72
73 if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
74 send_warning_opt(client, query_ctx).await?;
75 Ok(resps)
76 } else {
77 let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;
78
79 let mut results = Vec::with_capacity(outputs.len());
80
81 for output in outputs {
82 let resp =
83 output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
84 results.push(resp);
85 }
86
87 send_warning_opt(client, query_ctx).await?;
88 Ok(results)
89 }
90 }
91}
92
93async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
94where
95 C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
96 C::Error: Debug,
97 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
98{
99 if let Some(warning) = query_context.warning() {
100 client
101 .feed(PgWireBackendMessage::NoticeResponse(
102 ErrorInfo::new(
103 PgErrorSeverity::Warning.to_string(),
104 PgErrorCode::Ec01000.code(),
105 warning.to_string(),
106 )
107 .into(),
108 ))
109 .await?;
110 }
111
112 Ok(())
113}
114
115pub(crate) fn output_to_query_response<'a>(
116 query_ctx: QueryContextRef,
117 output: Result<Output>,
118 field_format: &Format,
119) -> PgWireResult<Response<'a>> {
120 match output {
121 Ok(o) => match o.data {
122 OutputData::AffectedRows(rows) => {
123 Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
124 }
125 OutputData::Stream(record_stream) => {
126 let schema = record_stream.schema();
127 recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
128 }
129 OutputData::RecordBatches(recordbatches) => {
130 let schema = recordbatches.schema();
131 recordbatches_to_query_response(
132 query_ctx,
133 recordbatches.as_stream(),
134 schema,
135 field_format,
136 )
137 }
138 },
139 Err(e) => Err(convert_err(e)),
140 }
141}
142
143fn recordbatches_to_query_response<'a, S>(
144 query_ctx: QueryContextRef,
145 recordbatches_stream: S,
146 schema: SchemaRef,
147 field_format: &Format,
148) -> PgWireResult<Response<'a>>
149where
150 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
151{
152 let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?);
153 let pg_schema_ref = pg_schema.clone();
154 let data_row_stream = recordbatches_stream
155 .map(|record_batch_result| match record_batch_result {
156 Ok(rb) => stream::iter(
157 rb.rows().map(Ok).collect::<Vec<_>>(),
160 )
161 .boxed(),
162 Err(e) => stream::once(future::err(convert_err(e))).boxed(),
163 })
164 .flatten() .map(move |row| {
166 row.and_then(|row| {
167 let mut encoder = DataRowEncoder::new(pg_schema_ref.clone());
168 for (value, column) in row.iter().zip(schema.column_schemas()) {
169 encode_value(&query_ctx, value, &mut encoder, &column.data_type)?;
170 }
171 encoder.finish()
172 })
173 });
174
175 Ok(Response::Query(QueryResponse::new(
176 pg_schema,
177 data_row_stream,
178 )))
179}
180
181pub struct DefaultQueryParser {
182 query_handler: ServerSqlQueryHandlerRef,
183 session: Arc<Session>,
184}
185
186impl DefaultQueryParser {
187 pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
188 DefaultQueryParser {
189 query_handler,
190 session,
191 }
192 }
193}
194
195#[async_trait]
196impl QueryParser for DefaultQueryParser {
197 type Statement = SqlPlan;
198
199 async fn parse_sql<C>(
200 &self,
201 _client: &C,
202 sql: &str,
203 _types: &[Type],
204 ) -> PgWireResult<Self::Statement> {
205 crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
206 let query_ctx = self.session.new_query_context();
207
208 if sql.is_empty() || fixtures::matches(sql) {
210 return Ok(SqlPlan {
211 query: sql.to_owned(),
212 statement: None,
213 plan: None,
214 schema: None,
215 });
216 }
217
218 let sql = fixtures::rewrite_sql(sql);
219 let sql = sql.as_ref();
220
221 let mut stmts =
222 ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default())
223 .map_err(convert_err)?;
224 if stmts.len() != 1 {
225 Err(PgWireError::UserError(Box::new(ErrorInfo::from(
226 PgErrorCode::Ec42P14,
227 ))))
228 } else {
229 let stmt = stmts.remove(0);
230
231 let describe_result = self
232 .query_handler
233 .do_describe(stmt.clone(), query_ctx)
234 .await
235 .map_err(convert_err)?;
236
237 let (plan, schema) = if let Some(DescribeResult {
238 logical_plan,
239 schema,
240 }) = describe_result
241 {
242 (Some(logical_plan), Some(schema))
243 } else {
244 (None, None)
245 };
246
247 Ok(SqlPlan {
248 query: sql.to_owned(),
249 statement: Some(stmt),
250 plan,
251 schema,
252 })
253 }
254 }
255}
256
257#[async_trait]
258impl ExtendedQueryHandler for PostgresServerHandlerInner {
259 type Statement = SqlPlan;
260 type QueryParser = DefaultQueryParser;
261
262 fn query_parser(&self) -> Arc<Self::QueryParser> {
263 self.query_parser.clone()
264 }
265
266 async fn do_query<'a, C>(
267 &self,
268 client: &mut C,
269 portal: &Portal<Self::Statement>,
270 _max_rows: usize,
271 ) -> PgWireResult<Response<'a>>
272 where
273 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
274 C::Error: Debug,
275 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
276 {
277 let query_ctx = self.session.new_query_context();
278 let db = query_ctx.get_db_string();
279 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
280 .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
281 .start_timer();
282
283 let sql_plan = &portal.statement.statement;
284
285 if sql_plan.query.is_empty() {
286 return Ok(Response::EmptyQuery);
288 }
289
290 if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
291 send_warning_opt(client, query_ctx).await?;
292 return Ok(resps.remove(0));
294 }
295
296 let output = if let Some(plan) = &sql_plan.plan {
297 let plan = plan
298 .clone()
299 .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
300 plan, portal,
301 )?))
302 .context(DataFusionSnafu)
303 .map_err(convert_err)?;
304 self.query_handler
305 .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
306 .await
307 } else {
308 let mut sql = sql_plan.query.clone();
312 for i in 0..portal.parameter_len() {
313 sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?);
314 }
315
316 self.query_handler
317 .do_query(&sql, query_ctx.clone())
318 .await
319 .remove(0)
320 };
321
322 send_warning_opt(client, query_ctx.clone()).await?;
323 output_to_query_response(query_ctx, output, &portal.result_column_format)
324 }
325
326 async fn do_describe_statement<C>(
327 &self,
328 _client: &mut C,
329 stmt: &StoredStatement<Self::Statement>,
330 ) -> PgWireResult<DescribeStatementResponse>
331 where
332 C: ClientInfo + Unpin + Send + Sync,
333 {
334 let sql_plan = &stmt.statement;
335 let (param_types, sql_plan, format) = if let Some(plan) = &sql_plan.plan {
336 let param_types = plan
337 .get_parameter_types()
338 .context(DataFusionSnafu)
339 .map_err(convert_err)?
340 .into_iter()
341 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
342 .collect();
343
344 let types = param_types_to_pg_types(¶m_types).map_err(convert_err)?;
345
346 (types, sql_plan, &Format::UnifiedBinary)
347 } else {
348 let param_types = stmt.parameter_types.clone();
349 (param_types, sql_plan, &Format::UnifiedBinary)
350 };
351
352 if let Some(schema) = &sql_plan.schema {
353 schema_to_pg(schema, format)
354 .map(|fields| DescribeStatementResponse::new(param_types, fields))
355 .map_err(convert_err)
356 } else {
357 if let Some(mut resp) =
358 fixtures::process(&sql_plan.query, self.session.new_query_context())
359 {
360 if let Response::Query(query_response) = resp.remove(0) {
361 return Ok(DescribeStatementResponse::new(
362 param_types,
363 (*query_response.row_schema()).clone(),
364 ));
365 }
366 }
367
368 Ok(DescribeStatementResponse::new(param_types, vec![]))
369 }
370 }
371
372 async fn do_describe_portal<C>(
373 &self,
374 _client: &mut C,
375 portal: &Portal<Self::Statement>,
376 ) -> PgWireResult<DescribePortalResponse>
377 where
378 C: ClientInfo + Unpin + Send + Sync,
379 {
380 let sql_plan = &portal.statement.statement;
381 let format = &portal.result_column_format;
382
383 if let Some(schema) = &sql_plan.schema {
384 schema_to_pg(schema, format)
385 .map(DescribePortalResponse::new)
386 .map_err(convert_err)
387 } else {
388 if let Some(mut resp) =
389 fixtures::process(&sql_plan.query, self.session.new_query_context())
390 {
391 if let Response::Query(query_response) = resp.remove(0) {
392 return Ok(DescribePortalResponse::new(
393 (*query_response.row_schema()).clone(),
394 ));
395 }
396 }
397
398 Ok(DescribePortalResponse::new(vec![]))
399 }
400 }
401}
402
403impl ErrorHandler for PostgresServerHandlerInner {
404 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
405 where
406 C: ClientInfo,
407 {
408 debug!("Postgres interface error {}", error)
409 }
410}