// servers/postgres/handler.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_error::ext::ErrorExt;
20use common_query::{Output, OutputData};
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_recordbatch::RecordBatch;
23use common_telemetry::{debug, error, tracing};
24use datafusion_common::ParamValues;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{future, stream, Sink, SinkExt, Stream, StreamExt};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31    DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::context::QueryContextRef;
39use session::Session;
40use sql::dialect::PostgreSqlDialect;
41use sql::parser::{ParseOptions, ParserContext};
42
43use crate::error::Result;
44use crate::postgres::types::*;
45use crate::postgres::{fixtures, PostgresServerHandlerInner};
46use crate::query_handler::sql::ServerSqlQueryHandlerRef;
47use crate::SqlPlan;
48
49#[async_trait]
50impl SimpleQueryHandler for PostgresServerHandlerInner {
51    #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
52    async fn do_query<'a, C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
53    where
54        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
55        C::Error: Debug,
56        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
57    {
58        let query_ctx = self.session.new_query_context();
59        let db = query_ctx.get_db_string();
60        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
61            .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
62            .start_timer();
63
64        if query.is_empty() {
65            // early return if query is empty
66            return Ok(vec![Response::EmptyQuery]);
67        }
68
69        let query = fixtures::rewrite_sql(query);
70        let query = query.as_ref();
71
72        if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
73            send_warning_opt(client, query_ctx).await?;
74            Ok(resps)
75        } else {
76            let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;
77
78            let mut results = Vec::with_capacity(outputs.len());
79
80            for output in outputs {
81                let resp =
82                    output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
83                results.push(resp);
84            }
85
86            send_warning_opt(client, query_ctx).await?;
87            Ok(results)
88        }
89    }
90}
91
92async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
93where
94    C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
95    C::Error: Debug,
96    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
97{
98    if let Some(warning) = query_context.warning() {
99        client
100            .feed(PgWireBackendMessage::NoticeResponse(
101                ErrorInfo::new(
102                    PgErrorSeverity::Warning.to_string(),
103                    PgErrorCode::Ec01000.code(),
104                    warning.to_string(),
105                )
106                .into(),
107            ))
108            .await?;
109    }
110
111    Ok(())
112}
113
114pub(crate) fn output_to_query_response<'a>(
115    query_ctx: QueryContextRef,
116    output: Result<Output>,
117    field_format: &Format,
118) -> PgWireResult<Response<'a>> {
119    match output {
120        Ok(o) => match o.data {
121            OutputData::AffectedRows(rows) => {
122                Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
123            }
124            OutputData::Stream(record_stream) => {
125                let schema = record_stream.schema();
126                recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
127            }
128            OutputData::RecordBatches(recordbatches) => {
129                let schema = recordbatches.schema();
130                recordbatches_to_query_response(
131                    query_ctx,
132                    recordbatches.as_stream(),
133                    schema,
134                    field_format,
135                )
136            }
137        },
138        Err(e) => {
139            let status_code = e.status_code();
140
141            if status_code.should_log_error() {
142                let root_error = e.root_cause().unwrap_or(&e);
143                error!(e; "Failed to handle postgres query, code: {}, db: {}, error: {}", status_code, query_ctx.get_db_string(), root_error.to_string());
144            } else {
145                debug!(
146                    "Failed to handle postgres query, code: {}, db: {}, error: {:?}",
147                    status_code,
148                    query_ctx.get_db_string(),
149                    e
150                );
151            };
152            Ok(Response::Error(Box::new(
153                PgErrorCode::from(status_code).to_err_info(e.output_msg()),
154            )))
155        }
156    }
157}
158
159fn recordbatches_to_query_response<'a, S>(
160    query_ctx: QueryContextRef,
161    recordbatches_stream: S,
162    schema: SchemaRef,
163    field_format: &Format,
164) -> PgWireResult<Response<'a>>
165where
166    S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
167{
168    let pg_schema = Arc::new(
169        schema_to_pg(schema.as_ref(), field_format)
170            .map_err(|e| PgWireError::ApiError(Box::new(e)))?,
171    );
172    let pg_schema_ref = pg_schema.clone();
173    let data_row_stream = recordbatches_stream
174        .map(|record_batch_result| match record_batch_result {
175            Ok(rb) => stream::iter(
176                // collect rows from a single recordbatch into vector to avoid
177                // borrowing it
178                rb.rows().map(Ok).collect::<Vec<_>>(),
179            )
180            .boxed(),
181            Err(e) => stream::once(future::err(PgWireError::ApiError(Box::new(e)))).boxed(),
182        })
183        .flatten() // flatten into stream<result<row>>
184        .map(move |row| {
185            row.and_then(|row| {
186                let mut encoder = DataRowEncoder::new(pg_schema_ref.clone());
187                for (value, column) in row.iter().zip(schema.column_schemas()) {
188                    encode_value(&query_ctx, value, &mut encoder, &column.data_type)?;
189                }
190                encoder.finish()
191            })
192        });
193
194    Ok(Response::Query(QueryResponse::new(
195        pg_schema,
196        data_row_stream,
197    )))
198}
199
200pub struct DefaultQueryParser {
201    query_handler: ServerSqlQueryHandlerRef,
202    session: Arc<Session>,
203}
204
205impl DefaultQueryParser {
206    pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
207        DefaultQueryParser {
208            query_handler,
209            session,
210        }
211    }
212}
213
214#[async_trait]
215impl QueryParser for DefaultQueryParser {
216    type Statement = SqlPlan;
217
218    async fn parse_sql<C>(
219        &self,
220        _client: &C,
221        sql: &str,
222        _types: &[Type],
223    ) -> PgWireResult<Self::Statement> {
224        crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
225        let query_ctx = self.session.new_query_context();
226
227        // do not parse if query is empty or matches rules
228        if sql.is_empty() || fixtures::matches(sql) {
229            return Ok(SqlPlan {
230                query: sql.to_owned(),
231                plan: None,
232                schema: None,
233            });
234        }
235
236        let sql = fixtures::rewrite_sql(sql);
237        let sql = sql.as_ref();
238
239        let mut stmts =
240            ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default())
241                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
242        if stmts.len() != 1 {
243            Err(PgWireError::UserError(Box::new(ErrorInfo::from(
244                PgErrorCode::Ec42P14,
245            ))))
246        } else {
247            let stmt = stmts.remove(0);
248
249            let describe_result = self
250                .query_handler
251                .do_describe(stmt, query_ctx)
252                .await
253                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
254
255            let (plan, schema) = if let Some(DescribeResult {
256                logical_plan,
257                schema,
258            }) = describe_result
259            {
260                (Some(logical_plan), Some(schema))
261            } else {
262                (None, None)
263            };
264
265            Ok(SqlPlan {
266                query: sql.to_owned(),
267                plan,
268                schema,
269            })
270        }
271    }
272}
273
274#[async_trait]
275impl ExtendedQueryHandler for PostgresServerHandlerInner {
276    type Statement = SqlPlan;
277    type QueryParser = DefaultQueryParser;
278
279    fn query_parser(&self) -> Arc<Self::QueryParser> {
280        self.query_parser.clone()
281    }
282
283    async fn do_query<'a, C>(
284        &self,
285        client: &mut C,
286        portal: &Portal<Self::Statement>,
287        _max_rows: usize,
288    ) -> PgWireResult<Response<'a>>
289    where
290        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
291        C::Error: Debug,
292        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
293    {
294        let query_ctx = self.session.new_query_context();
295        let db = query_ctx.get_db_string();
296        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
297            .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
298            .start_timer();
299
300        let sql_plan = &portal.statement.statement;
301
302        if sql_plan.query.is_empty() {
303            // early return if query is empty
304            return Ok(Response::EmptyQuery);
305        }
306
307        if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
308            send_warning_opt(client, query_ctx).await?;
309            // if the statement matches our predefined rules, return it early
310            return Ok(resps.remove(0));
311        }
312
313        let output = if let Some(plan) = &sql_plan.plan {
314            let plan = plan
315                .clone()
316                .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
317                    plan, portal,
318                )?))
319                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
320            self.query_handler
321                .do_exec_plan(plan, query_ctx.clone())
322                .await
323        } else {
324            // manually replace variables in prepared statement when no
325            // logical_plan is generated. This happens when logical plan is not
326            // supported for certain statements.
327            let mut sql = sql_plan.query.clone();
328            for i in 0..portal.parameter_len() {
329                sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
330            }
331
332            self.query_handler
333                .do_query(&sql, query_ctx.clone())
334                .await
335                .remove(0)
336        };
337
338        send_warning_opt(client, query_ctx.clone()).await?;
339        output_to_query_response(query_ctx, output, &portal.result_column_format)
340    }
341
342    async fn do_describe_statement<C>(
343        &self,
344        _client: &mut C,
345        stmt: &StoredStatement<Self::Statement>,
346    ) -> PgWireResult<DescribeStatementResponse>
347    where
348        C: ClientInfo + Unpin + Send + Sync,
349    {
350        let sql_plan = &stmt.statement;
351        let (param_types, sql_plan, format) = if let Some(plan) = &sql_plan.plan {
352            let param_types = plan
353                .get_parameter_types()
354                .map_err(|e| PgWireError::ApiError(Box::new(e)))?
355                .into_iter()
356                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
357                .collect();
358
359            let types = param_types_to_pg_types(&param_types)
360                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
361
362            (types, sql_plan, &Format::UnifiedBinary)
363        } else {
364            let param_types = stmt.parameter_types.clone();
365            (param_types, sql_plan, &Format::UnifiedBinary)
366        };
367
368        if let Some(schema) = &sql_plan.schema {
369            schema_to_pg(schema, format)
370                .map(|fields| DescribeStatementResponse::new(param_types, fields))
371                .map_err(|e| PgWireError::ApiError(Box::new(e)))
372        } else {
373            if let Some(mut resp) =
374                fixtures::process(&sql_plan.query, self.session.new_query_context())
375            {
376                if let Response::Query(query_response) = resp.remove(0) {
377                    return Ok(DescribeStatementResponse::new(
378                        param_types,
379                        (*query_response.row_schema()).clone(),
380                    ));
381                }
382            }
383
384            Ok(DescribeStatementResponse::new(param_types, vec![]))
385        }
386    }
387
388    async fn do_describe_portal<C>(
389        &self,
390        _client: &mut C,
391        portal: &Portal<Self::Statement>,
392    ) -> PgWireResult<DescribePortalResponse>
393    where
394        C: ClientInfo + Unpin + Send + Sync,
395    {
396        let sql_plan = &portal.statement.statement;
397        let format = &portal.result_column_format;
398
399        if let Some(schema) = &sql_plan.schema {
400            schema_to_pg(schema, format)
401                .map(DescribePortalResponse::new)
402                .map_err(|e| PgWireError::ApiError(Box::new(e)))
403        } else {
404            if let Some(mut resp) =
405                fixtures::process(&sql_plan.query, self.session.new_query_context())
406            {
407                if let Response::Query(query_response) = resp.remove(0) {
408                    return Ok(DescribePortalResponse::new(
409                        (*query_response.row_schema()).clone(),
410                    ));
411                }
412            }
413
414            Ok(DescribePortalResponse::new(vec![]))
415        }
416    }
417}
418
419impl ErrorHandler for PostgresServerHandlerInner {
420    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
421    where
422        C: ClientInfo,
423    {
424        debug!("Postgres interface error {}", error)
425    }
426}