1use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_error::ext::ErrorExt;
20use common_query::{Output, OutputData};
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_recordbatch::RecordBatch;
23use common_telemetry::{debug, error, tracing};
24use datafusion_common::ParamValues;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{future, stream, Sink, SinkExt, Stream, StreamExt};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31 DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::context::QueryContextRef;
39use session::Session;
40use sql::dialect::PostgreSqlDialect;
41use sql::parser::{ParseOptions, ParserContext};
42
43use crate::error::Result;
44use crate::postgres::types::*;
45use crate::postgres::{fixtures, PostgresServerHandlerInner};
46use crate::query_handler::sql::ServerSqlQueryHandlerRef;
47use crate::SqlPlan;
48
49#[async_trait]
50impl SimpleQueryHandler for PostgresServerHandlerInner {
51 #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
52 async fn do_query<'a, C>(
53 &self,
54 client: &mut C,
55 query: &'a str,
56 ) -> PgWireResult<Vec<Response<'a>>>
57 where
58 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
59 C::Error: Debug,
60 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
61 {
62 let query_ctx = self.session.new_query_context();
63 let db = query_ctx.get_db_string();
64 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
65 .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
66 .start_timer();
67
68 if query.is_empty() {
69 return Ok(vec![Response::EmptyQuery]);
71 }
72
73 let query = fixtures::rewrite_sql(query);
74 let query = query.as_ref();
75
76 if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
77 send_warning_opt(client, query_ctx).await?;
78 Ok(resps)
79 } else {
80 let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;
81
82 let mut results = Vec::with_capacity(outputs.len());
83
84 for output in outputs {
85 let resp =
86 output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
87 results.push(resp);
88 }
89
90 send_warning_opt(client, query_ctx).await?;
91 Ok(results)
92 }
93 }
94}
95
96async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
97where
98 C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
99 C::Error: Debug,
100 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
101{
102 if let Some(warning) = query_context.warning() {
103 client
104 .feed(PgWireBackendMessage::NoticeResponse(
105 ErrorInfo::new(
106 PgErrorSeverity::Warning.to_string(),
107 PgErrorCode::Ec01000.code(),
108 warning.to_string(),
109 )
110 .into(),
111 ))
112 .await?;
113 }
114
115 Ok(())
116}
117
118pub(crate) fn output_to_query_response<'a>(
119 query_ctx: QueryContextRef,
120 output: Result<Output>,
121 field_format: &Format,
122) -> PgWireResult<Response<'a>> {
123 match output {
124 Ok(o) => match o.data {
125 OutputData::AffectedRows(rows) => {
126 Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
127 }
128 OutputData::Stream(record_stream) => {
129 let schema = record_stream.schema();
130 recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
131 }
132 OutputData::RecordBatches(recordbatches) => {
133 let schema = recordbatches.schema();
134 recordbatches_to_query_response(
135 query_ctx,
136 recordbatches.as_stream(),
137 schema,
138 field_format,
139 )
140 }
141 },
142 Err(e) => {
143 let status_code = e.status_code();
144
145 if status_code.should_log_error() {
146 let root_error = e.root_cause().unwrap_or(&e);
147 error!(e; "Failed to handle postgres query, code: {}, db: {}, error: {}", status_code, query_ctx.get_db_string(), root_error.to_string());
148 } else {
149 debug!(
150 "Failed to handle postgres query, code: {}, db: {}, error: {:?}",
151 status_code,
152 query_ctx.get_db_string(),
153 e
154 );
155 };
156 Ok(Response::Error(Box::new(
157 PgErrorCode::from(status_code).to_err_info(e.output_msg()),
158 )))
159 }
160 }
161}
162
163fn recordbatches_to_query_response<'a, S>(
164 query_ctx: QueryContextRef,
165 recordbatches_stream: S,
166 schema: SchemaRef,
167 field_format: &Format,
168) -> PgWireResult<Response<'a>>
169where
170 S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
171{
172 let pg_schema = Arc::new(
173 schema_to_pg(schema.as_ref(), field_format)
174 .map_err(|e| PgWireError::ApiError(Box::new(e)))?,
175 );
176 let pg_schema_ref = pg_schema.clone();
177 let data_row_stream = recordbatches_stream
178 .map(|record_batch_result| match record_batch_result {
179 Ok(rb) => stream::iter(
180 rb.rows().map(Ok).collect::<Vec<_>>(),
183 )
184 .boxed(),
185 Err(e) => stream::once(future::err(PgWireError::ApiError(Box::new(e)))).boxed(),
186 })
187 .flatten() .map(move |row| {
189 row.and_then(|row| {
190 let mut encoder = DataRowEncoder::new(pg_schema_ref.clone());
191 for (value, column) in row.iter().zip(schema.column_schemas()) {
192 encode_value(&query_ctx, value, &mut encoder, &column.data_type)?;
193 }
194 encoder.finish()
195 })
196 });
197
198 Ok(Response::Query(QueryResponse::new(
199 pg_schema,
200 data_row_stream,
201 )))
202}
203
204pub struct DefaultQueryParser {
205 query_handler: ServerSqlQueryHandlerRef,
206 session: Arc<Session>,
207}
208
209impl DefaultQueryParser {
210 pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
211 DefaultQueryParser {
212 query_handler,
213 session,
214 }
215 }
216}
217
218#[async_trait]
219impl QueryParser for DefaultQueryParser {
220 type Statement = SqlPlan;
221
222 async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
223 crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
224 let query_ctx = self.session.new_query_context();
225
226 if sql.is_empty() || fixtures::matches(sql) {
228 return Ok(SqlPlan {
229 query: sql.to_owned(),
230 plan: None,
231 schema: None,
232 });
233 }
234
235 let sql = fixtures::rewrite_sql(sql);
236 let sql = sql.as_ref();
237
238 let mut stmts =
239 ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default())
240 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
241 if stmts.len() != 1 {
242 Err(PgWireError::UserError(Box::new(ErrorInfo::from(
243 PgErrorCode::Ec42P14,
244 ))))
245 } else {
246 let stmt = stmts.remove(0);
247
248 let describe_result = self
249 .query_handler
250 .do_describe(stmt, query_ctx)
251 .await
252 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
253
254 let (plan, schema) = if let Some(DescribeResult {
255 logical_plan,
256 schema,
257 }) = describe_result
258 {
259 (Some(logical_plan), Some(schema))
260 } else {
261 (None, None)
262 };
263
264 Ok(SqlPlan {
265 query: sql.to_owned(),
266 plan,
267 schema,
268 })
269 }
270 }
271}
272
273#[async_trait]
274impl ExtendedQueryHandler for PostgresServerHandlerInner {
275 type Statement = SqlPlan;
276 type QueryParser = DefaultQueryParser;
277
278 fn query_parser(&self) -> Arc<Self::QueryParser> {
279 self.query_parser.clone()
280 }
281
282 async fn do_query<'a, C>(
283 &self,
284 client: &mut C,
285 portal: &'a Portal<Self::Statement>,
286 _max_rows: usize,
287 ) -> PgWireResult<Response<'a>>
288 where
289 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
290 C::Error: Debug,
291 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
292 {
293 let query_ctx = self.session.new_query_context();
294 let db = query_ctx.get_db_string();
295 let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
296 .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
297 .start_timer();
298
299 let sql_plan = &portal.statement.statement;
300
301 if sql_plan.query.is_empty() {
302 return Ok(Response::EmptyQuery);
304 }
305
306 if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
307 send_warning_opt(client, query_ctx).await?;
308 return Ok(resps.remove(0));
310 }
311
312 let output = if let Some(plan) = &sql_plan.plan {
313 let plan = plan
314 .clone()
315 .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
316 plan, portal,
317 )?))
318 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
319 self.query_handler
320 .do_exec_plan(plan, query_ctx.clone())
321 .await
322 } else {
323 let mut sql = sql_plan.query.clone();
327 for i in 0..portal.parameter_len() {
328 sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?);
329 }
330
331 self.query_handler
332 .do_query(&sql, query_ctx.clone())
333 .await
334 .remove(0)
335 };
336
337 send_warning_opt(client, query_ctx.clone()).await?;
338 output_to_query_response(query_ctx, output, &portal.result_column_format)
339 }
340
341 async fn do_describe_statement<C>(
342 &self,
343 _client: &mut C,
344 stmt: &StoredStatement<Self::Statement>,
345 ) -> PgWireResult<DescribeStatementResponse>
346 where
347 C: ClientInfo + Unpin + Send + Sync,
348 {
349 let sql_plan = &stmt.statement;
350 let (param_types, sql_plan, format) = if let Some(plan) = &sql_plan.plan {
351 let param_types = plan
352 .get_parameter_types()
353 .map_err(|e| PgWireError::ApiError(Box::new(e)))?
354 .into_iter()
355 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
356 .collect();
357
358 let types = param_types_to_pg_types(¶m_types)
359 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
360
361 (types, sql_plan, &Format::UnifiedBinary)
362 } else {
363 let param_types = stmt.parameter_types.clone();
364 (param_types, sql_plan, &Format::UnifiedBinary)
365 };
366
367 if let Some(schema) = &sql_plan.schema {
368 schema_to_pg(schema, format)
369 .map(|fields| DescribeStatementResponse::new(param_types, fields))
370 .map_err(|e| PgWireError::ApiError(Box::new(e)))
371 } else {
372 if let Some(mut resp) =
373 fixtures::process(&sql_plan.query, self.session.new_query_context())
374 {
375 if let Response::Query(query_response) = resp.remove(0) {
376 return Ok(DescribeStatementResponse::new(
377 param_types,
378 (*query_response.row_schema()).clone(),
379 ));
380 }
381 }
382
383 Ok(DescribeStatementResponse::new(param_types, vec![]))
384 }
385 }
386
387 async fn do_describe_portal<C>(
388 &self,
389 _client: &mut C,
390 portal: &Portal<Self::Statement>,
391 ) -> PgWireResult<DescribePortalResponse>
392 where
393 C: ClientInfo + Unpin + Send + Sync,
394 {
395 let sql_plan = &portal.statement.statement;
396 let format = &portal.result_column_format;
397
398 if let Some(schema) = &sql_plan.schema {
399 schema_to_pg(schema, format)
400 .map(DescribePortalResponse::new)
401 .map_err(|e| PgWireError::ApiError(Box::new(e)))
402 } else {
403 if let Some(mut resp) =
404 fixtures::process(&sql_plan.query, self.session.new_query_context())
405 {
406 if let Response::Query(query_response) = resp.remove(0) {
407 return Ok(DescribePortalResponse::new(
408 (*query_response.row_schema()).clone(),
409 ));
410 }
411 }
412
413 Ok(DescribePortalResponse::new(vec![]))
414 }
415 }
416}
417
418impl ErrorHandler for PostgresServerHandlerInner {
419 fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
420 where
421 C: ClientInfo,
422 {
423 debug!("Postgres interface error {}", error)
424 }
425}