1use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_error::ext::ErrorExt;
20use common_query::{Output, OutputData};
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_recordbatch::RecordBatch;
23use common_telemetry::{debug, error, tracing};
24use datafusion_common::ParamValues;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{future, stream, Sink, SinkExt, Stream, StreamExt};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31 DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::context::QueryContextRef;
39use session::Session;
40use sql::dialect::PostgreSqlDialect;
41use sql::parser::{ParseOptions, ParserContext};
42
43use crate::error::Result;
44use crate::postgres::types::*;
45use crate::postgres::{fixtures, PostgresServerHandlerInner};
46use crate::query_handler::sql::ServerSqlQueryHandlerRef;
47use crate::SqlPlan;
48
49#[async_trait]
50impl SimpleQueryHandler for PostgresServerHandlerInner {
51 #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
52 async fn do_query<'a, C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
53 where
54 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
55 C::Error: Debug,
56 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
57 {
58 let query_ctx = self.session.new_query_context();
59 let db = query_ctx.get_db_string();
60 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
61 .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
62 .start_timer();
63
64 if query.is_empty() {
65 return Ok(vec![Response::EmptyQuery]);
67 }
68
69 let query = fixtures::rewrite_sql(query);
70 let query = query.as_ref();
71
72 if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
73 send_warning_opt(client, query_ctx).await?;
74 Ok(resps)
75 } else {
76 let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;
77
78 let mut results = Vec::with_capacity(outputs.len());
79
80 for output in outputs {
81 let resp =
82 output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
83 results.push(resp);
84 }
85
86 send_warning_opt(client, query_ctx).await?;
87 Ok(results)
88 }
89 }
90}
91
92async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
93where
94 C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
95 C::Error: Debug,
96 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
97{
98 if let Some(warning) = query_context.warning() {
99 client
100 .feed(PgWireBackendMessage::NoticeResponse(
101 ErrorInfo::new(
102 PgErrorSeverity::Warning.to_string(),
103 PgErrorCode::Ec01000.code(),
104 warning.to_string(),
105 )
106 .into(),
107 ))
108 .await?;
109 }
110
111 Ok(())
112}
113
114pub(crate) fn output_to_query_response<'a>(
115 query_ctx: QueryContextRef,
116 output: Result<Output>,
117 field_format: &Format,
118) -> PgWireResult<Response<'a>> {
119 match output {
120 Ok(o) => match o.data {
121 OutputData::AffectedRows(rows) => {
122 Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
123 }
124 OutputData::Stream(record_stream) => {
125 let schema = record_stream.schema();
126 recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
127 }
128 OutputData::RecordBatches(recordbatches) => {
129 let schema = recordbatches.schema();
130 recordbatches_to_query_response(
131 query_ctx,
132 recordbatches.as_stream(),
133 schema,
134 field_format,
135 )
136 }
137 },
138 Err(e) => {
139 let status_code = e.status_code();
140
141 if status_code.should_log_error() {
142 let root_error = e.root_cause().unwrap_or(&e);
143 error!(e; "Failed to handle postgres query, code: {}, db: {}, error: {}", status_code, query_ctx.get_db_string(), root_error.to_string());
144 } else {
145 debug!(
146 "Failed to handle postgres query, code: {}, db: {}, error: {:?}",
147 status_code,
148 query_ctx.get_db_string(),
149 e
150 );
151 };
152 Ok(Response::Error(Box::new(
153 PgErrorCode::from(status_code).to_err_info(e.output_msg()),
154 )))
155 }
156 }
157}
158
159fn recordbatches_to_query_response<'a, S>(
160 query_ctx: QueryContextRef,
161 recordbatches_stream: S,
162 schema: SchemaRef,
163 field_format: &Format,
164) -> PgWireResult<Response<'a>>
165where
166 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
167{
168 let pg_schema = Arc::new(
169 schema_to_pg(schema.as_ref(), field_format)
170 .map_err(|e| PgWireError::ApiError(Box::new(e)))?,
171 );
172 let pg_schema_ref = pg_schema.clone();
173 let data_row_stream = recordbatches_stream
174 .map(|record_batch_result| match record_batch_result {
175 Ok(rb) => stream::iter(
176 rb.rows().map(Ok).collect::<Vec<_>>(),
179 )
180 .boxed(),
181 Err(e) => stream::once(future::err(PgWireError::ApiError(Box::new(e)))).boxed(),
182 })
183 .flatten() .map(move |row| {
185 row.and_then(|row| {
186 let mut encoder = DataRowEncoder::new(pg_schema_ref.clone());
187 for (value, column) in row.iter().zip(schema.column_schemas()) {
188 encode_value(&query_ctx, value, &mut encoder, &column.data_type)?;
189 }
190 encoder.finish()
191 })
192 });
193
194 Ok(Response::Query(QueryResponse::new(
195 pg_schema,
196 data_row_stream,
197 )))
198}
199
200pub struct DefaultQueryParser {
201 query_handler: ServerSqlQueryHandlerRef,
202 session: Arc<Session>,
203}
204
205impl DefaultQueryParser {
206 pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
207 DefaultQueryParser {
208 query_handler,
209 session,
210 }
211 }
212}
213
214#[async_trait]
215impl QueryParser for DefaultQueryParser {
216 type Statement = SqlPlan;
217
218 async fn parse_sql<C>(
219 &self,
220 _client: &C,
221 sql: &str,
222 _types: &[Type],
223 ) -> PgWireResult<Self::Statement> {
224 crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
225 let query_ctx = self.session.new_query_context();
226
227 if sql.is_empty() || fixtures::matches(sql) {
229 return Ok(SqlPlan {
230 query: sql.to_owned(),
231 plan: None,
232 schema: None,
233 });
234 }
235
236 let sql = fixtures::rewrite_sql(sql);
237 let sql = sql.as_ref();
238
239 let mut stmts =
240 ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default())
241 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
242 if stmts.len() != 1 {
243 Err(PgWireError::UserError(Box::new(ErrorInfo::from(
244 PgErrorCode::Ec42P14,
245 ))))
246 } else {
247 let stmt = stmts.remove(0);
248
249 let describe_result = self
250 .query_handler
251 .do_describe(stmt, query_ctx)
252 .await
253 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
254
255 let (plan, schema) = if let Some(DescribeResult {
256 logical_plan,
257 schema,
258 }) = describe_result
259 {
260 (Some(logical_plan), Some(schema))
261 } else {
262 (None, None)
263 };
264
265 Ok(SqlPlan {
266 query: sql.to_owned(),
267 plan,
268 schema,
269 })
270 }
271 }
272}
273
274#[async_trait]
275impl ExtendedQueryHandler for PostgresServerHandlerInner {
276 type Statement = SqlPlan;
277 type QueryParser = DefaultQueryParser;
278
279 fn query_parser(&self) -> Arc<Self::QueryParser> {
280 self.query_parser.clone()
281 }
282
283 async fn do_query<'a, C>(
284 &self,
285 client: &mut C,
286 portal: &Portal<Self::Statement>,
287 _max_rows: usize,
288 ) -> PgWireResult<Response<'a>>
289 where
290 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
291 C::Error: Debug,
292 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
293 {
294 let query_ctx = self.session.new_query_context();
295 let db = query_ctx.get_db_string();
296 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
297 .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
298 .start_timer();
299
300 let sql_plan = &portal.statement.statement;
301
302 if sql_plan.query.is_empty() {
303 return Ok(Response::EmptyQuery);
305 }
306
307 if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
308 send_warning_opt(client, query_ctx).await?;
309 return Ok(resps.remove(0));
311 }
312
313 let output = if let Some(plan) = &sql_plan.plan {
314 let plan = plan
315 .clone()
316 .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
317 plan, portal,
318 )?))
319 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
320 self.query_handler
321 .do_exec_plan(plan, query_ctx.clone())
322 .await
323 } else {
324 let mut sql = sql_plan.query.clone();
328 for i in 0..portal.parameter_len() {
329 sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?);
330 }
331
332 self.query_handler
333 .do_query(&sql, query_ctx.clone())
334 .await
335 .remove(0)
336 };
337
338 send_warning_opt(client, query_ctx.clone()).await?;
339 output_to_query_response(query_ctx, output, &portal.result_column_format)
340 }
341
342 async fn do_describe_statement<C>(
343 &self,
344 _client: &mut C,
345 stmt: &StoredStatement<Self::Statement>,
346 ) -> PgWireResult<DescribeStatementResponse>
347 where
348 C: ClientInfo + Unpin + Send + Sync,
349 {
350 let sql_plan = &stmt.statement;
351 let (param_types, sql_plan, format) = if let Some(plan) = &sql_plan.plan {
352 let param_types = plan
353 .get_parameter_types()
354 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
355 .into_iter()
356 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
357 .collect();
358
359 let types = param_types_to_pg_types(¶m_types)
360 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
361
362 (types, sql_plan, &Format::UnifiedBinary)
363 } else {
364 let param_types = stmt.parameter_types.clone();
365 (param_types, sql_plan, &Format::UnifiedBinary)
366 };
367
368 if let Some(schema) = &sql_plan.schema {
369 schema_to_pg(schema, format)
370 .map(|fields| DescribeStatementResponse::new(param_types, fields))
371 .map_err(|e| PgWireError::ApiError(Box::new(e)))
372 } else {
373 if let Some(mut resp) =
374 fixtures::process(&sql_plan.query, self.session.new_query_context())
375 {
376 if let Response::Query(query_response) = resp.remove(0) {
377 return Ok(DescribeStatementResponse::new(
378 param_types,
379 (*query_response.row_schema()).clone(),
380 ));
381 }
382 }
383
384 Ok(DescribeStatementResponse::new(param_types, vec![]))
385 }
386 }
387
388 async fn do_describe_portal<C>(
389 &self,
390 _client: &mut C,
391 portal: &Portal<Self::Statement>,
392 ) -> PgWireResult<DescribePortalResponse>
393 where
394 C: ClientInfo + Unpin + Send + Sync,
395 {
396 let sql_plan = &portal.statement.statement;
397 let format = &portal.result_column_format;
398
399 if let Some(schema) = &sql_plan.schema {
400 schema_to_pg(schema, format)
401 .map(DescribePortalResponse::new)
402 .map_err(|e| PgWireError::ApiError(Box::new(e)))
403 } else {
404 if let Some(mut resp) =
405 fixtures::process(&sql_plan.query, self.session.new_query_context())
406 {
407 if let Response::Query(query_response) = resp.remove(0) {
408 return Ok(DescribePortalResponse::new(
409 (*query_response.row_schema()).clone(),
410 ));
411 }
412 }
413
414 Ok(DescribePortalResponse::new(vec![]))
415 }
416 }
417}
418
419impl ErrorHandler for PostgresServerHandlerInner {
420 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
421 where
422 C: ClientInfo,
423 {
424 debug!("Postgres interface error {}", error)
425 }
426}