servers/postgres/
handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_error::ext::ErrorExt;
20use common_query::{Output, OutputData};
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_recordbatch::RecordBatch;
23use common_telemetry::{debug, error, tracing};
24use datafusion_common::ParamValues;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{future, stream, Sink, SinkExt, Stream, StreamExt};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31    DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::context::QueryContextRef;
39use session::Session;
40use sql::dialect::PostgreSqlDialect;
41use sql::parser::{ParseOptions, ParserContext};
42
43use crate::error::Result;
44use crate::postgres::types::*;
45use crate::postgres::{fixtures, PostgresServerHandlerInner};
46use crate::query_handler::sql::ServerSqlQueryHandlerRef;
47use crate::SqlPlan;
48
49#[async_trait]
50impl SimpleQueryHandler for PostgresServerHandlerInner {
51    #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
52    async fn do_query<'a, C>(
53        &self,
54        client: &mut C,
55        query: &'a str,
56    ) -> PgWireResult<Vec<Response<'a>>>
57    where
58        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
59        C::Error: Debug,
60        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
61    {
62        let query_ctx = self.session.new_query_context();
63        let db = query_ctx.get_db_string();
64        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
65            .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
66            .start_timer();
67
68        if query.is_empty() {
69            // early return if query is empty
70            return Ok(vec![Response::EmptyQuery]);
71        }
72
73        let query = fixtures::rewrite_sql(query);
74        let query = query.as_ref();
75
76        if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
77            send_warning_opt(client, query_ctx).await?;
78            Ok(resps)
79        } else {
80            let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;
81
82            let mut results = Vec::with_capacity(outputs.len());
83
84            for output in outputs {
85                let resp =
86                    output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
87                results.push(resp);
88            }
89
90            send_warning_opt(client, query_ctx).await?;
91            Ok(results)
92        }
93    }
94}
95
96async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
97where
98    C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
99    C::Error: Debug,
100    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
101{
102    if let Some(warning) = query_context.warning() {
103        client
104            .feed(PgWireBackendMessage::NoticeResponse(
105                ErrorInfo::new(
106                    PgErrorSeverity::Warning.to_string(),
107                    PgErrorCode::Ec01000.code(),
108                    warning.to_string(),
109                )
110                .into(),
111            ))
112            .await?;
113    }
114
115    Ok(())
116}
117
118pub(crate) fn output_to_query_response<'a>(
119    query_ctx: QueryContextRef,
120    output: Result<Output>,
121    field_format: &Format,
122) -> PgWireResult<Response<'a>> {
123    match output {
124        Ok(o) => match o.data {
125            OutputData::AffectedRows(rows) => {
126                Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
127            }
128            OutputData::Stream(record_stream) => {
129                let schema = record_stream.schema();
130                recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
131            }
132            OutputData::RecordBatches(recordbatches) => {
133                let schema = recordbatches.schema();
134                recordbatches_to_query_response(
135                    query_ctx,
136                    recordbatches.as_stream(),
137                    schema,
138                    field_format,
139                )
140            }
141        },
142        Err(e) => {
143            let status_code = e.status_code();
144
145            if status_code.should_log_error() {
146                let root_error = e.root_cause().unwrap_or(&e);
147                error!(e; "Failed to handle postgres query, code: {}, db: {}, error: {}", status_code, query_ctx.get_db_string(), root_error.to_string());
148            } else {
149                debug!(
150                    "Failed to handle postgres query, code: {}, db: {}, error: {:?}",
151                    status_code,
152                    query_ctx.get_db_string(),
153                    e
154                );
155            };
156            Ok(Response::Error(Box::new(
157                PgErrorCode::from(status_code).to_err_info(e.output_msg()),
158            )))
159        }
160    }
161}
162
163fn recordbatches_to_query_response<'a, S>(
164    query_ctx: QueryContextRef,
165    recordbatches_stream: S,
166    schema: SchemaRef,
167    field_format: &Format,
168) -> PgWireResult<Response<'a>>
169where
170    S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
171{
172    let pg_schema = Arc::new(
173        schema_to_pg(schema.as_ref(), field_format)
174            .map_err(|e| PgWireError::ApiError(Box::new(e)))?,
175    );
176    let pg_schema_ref = pg_schema.clone();
177    let data_row_stream = recordbatches_stream
178        .map(|record_batch_result| match record_batch_result {
179            Ok(rb) => stream::iter(
180                // collect rows from a single recordbatch into vector to avoid
181                // borrowing it
182                rb.rows().map(Ok).collect::<Vec<_>>(),
183            )
184            .boxed(),
185            Err(e) => stream::once(future::err(PgWireError::ApiError(Box::new(e)))).boxed(),
186        })
187        .flatten() // flatten into stream<result<row>>
188        .map(move |row| {
189            row.and_then(|row| {
190                let mut encoder = DataRowEncoder::new(pg_schema_ref.clone());
191                for (value, column) in row.iter().zip(schema.column_schemas()) {
192                    encode_value(&query_ctx, value, &mut encoder, &column.data_type)?;
193                }
194                encoder.finish()
195            })
196        });
197
198    Ok(Response::Query(QueryResponse::new(
199        pg_schema,
200        data_row_stream,
201    )))
202}
203
204pub struct DefaultQueryParser {
205    query_handler: ServerSqlQueryHandlerRef,
206    session: Arc<Session>,
207}
208
209impl DefaultQueryParser {
210    pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
211        DefaultQueryParser {
212            query_handler,
213            session,
214        }
215    }
216}
217
218#[async_trait]
219impl QueryParser for DefaultQueryParser {
220    type Statement = SqlPlan;
221
222    async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
223        crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
224        let query_ctx = self.session.new_query_context();
225
226        // do not parse if query is empty or matches rules
227        if sql.is_empty() || fixtures::matches(sql) {
228            return Ok(SqlPlan {
229                query: sql.to_owned(),
230                plan: None,
231                schema: None,
232            });
233        }
234
235        let sql = fixtures::rewrite_sql(sql);
236        let sql = sql.as_ref();
237
238        let mut stmts =
239            ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default())
240                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
241        if stmts.len() != 1 {
242            Err(PgWireError::UserError(Box::new(ErrorInfo::from(
243                PgErrorCode::Ec42P14,
244            ))))
245        } else {
246            let stmt = stmts.remove(0);
247
248            let describe_result = self
249                .query_handler
250                .do_describe(stmt, query_ctx)
251                .await
252                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
253
254            let (plan, schema) = if let Some(DescribeResult {
255                logical_plan,
256                schema,
257            }) = describe_result
258            {
259                (Some(logical_plan), Some(schema))
260            } else {
261                (None, None)
262            };
263
264            Ok(SqlPlan {
265                query: sql.to_owned(),
266                plan,
267                schema,
268            })
269        }
270    }
271}
272
273#[async_trait]
274impl ExtendedQueryHandler for PostgresServerHandlerInner {
275    type Statement = SqlPlan;
276    type QueryParser = DefaultQueryParser;
277
278    fn query_parser(&self) -> Arc<Self::QueryParser> {
279        self.query_parser.clone()
280    }
281
282    async fn do_query<'a, C>(
283        &self,
284        client: &mut C,
285        portal: &'a Portal<Self::Statement>,
286        _max_rows: usize,
287    ) -> PgWireResult<Response<'a>>
288    where
289        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
290        C::Error: Debug,
291        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
292    {
293        let query_ctx = self.session.new_query_context();
294        let db = query_ctx.get_db_string();
295        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
296            .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
297            .start_timer();
298
299        let sql_plan = &portal.statement.statement;
300
301        if sql_plan.query.is_empty() {
302            // early return if query is empty
303            return Ok(Response::EmptyQuery);
304        }
305
306        if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
307            send_warning_opt(client, query_ctx).await?;
308            // if the statement matches our predefined rules, return it early
309            return Ok(resps.remove(0));
310        }
311
312        let output = if let Some(plan) = &sql_plan.plan {
313            let plan = plan
314                .clone()
315                .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
316                    plan, portal,
317                )?))
318                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
319            self.query_handler
320                .do_exec_plan(plan, query_ctx.clone())
321                .await
322        } else {
323            // manually replace variables in prepared statement when no
324            // logical_plan is generated. This happens when logical plan is not
325            // supported for certain statements.
326            let mut sql = sql_plan.query.clone();
327            for i in 0..portal.parameter_len() {
328                sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
329            }
330
331            self.query_handler
332                .do_query(&sql, query_ctx.clone())
333                .await
334                .remove(0)
335        };
336
337        send_warning_opt(client, query_ctx.clone()).await?;
338        output_to_query_response(query_ctx, output, &portal.result_column_format)
339    }
340
341    async fn do_describe_statement<C>(
342        &self,
343        _client: &mut C,
344        stmt: &StoredStatement<Self::Statement>,
345    ) -> PgWireResult<DescribeStatementResponse>
346    where
347        C: ClientInfo + Unpin + Send + Sync,
348    {
349        let sql_plan = &stmt.statement;
350        let (param_types, sql_plan, format) = if let Some(plan) = &sql_plan.plan {
351            let param_types = plan
352                .get_parameter_types()
353                .map_err(|e| PgWireError::ApiError(Box::new(e)))?
354                .into_iter()
355                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
356                .collect();
357
358            let types = param_types_to_pg_types(&param_types)
359                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
360
361            (types, sql_plan, &Format::UnifiedBinary)
362        } else {
363            let param_types = stmt.parameter_types.clone();
364            (param_types, sql_plan, &Format::UnifiedBinary)
365        };
366
367        if let Some(schema) = &sql_plan.schema {
368            schema_to_pg(schema, format)
369                .map(|fields| DescribeStatementResponse::new(param_types, fields))
370                .map_err(|e| PgWireError::ApiError(Box::new(e)))
371        } else {
372            if let Some(mut resp) =
373                fixtures::process(&sql_plan.query, self.session.new_query_context())
374            {
375                if let Response::Query(query_response) = resp.remove(0) {
376                    return Ok(DescribeStatementResponse::new(
377                        param_types,
378                        (*query_response.row_schema()).clone(),
379                    ));
380                }
381            }
382
383            Ok(DescribeStatementResponse::new(param_types, vec![]))
384        }
385    }
386
387    async fn do_describe_portal<C>(
388        &self,
389        _client: &mut C,
390        portal: &Portal<Self::Statement>,
391    ) -> PgWireResult<DescribePortalResponse>
392    where
393        C: ClientInfo + Unpin + Send + Sync,
394    {
395        let sql_plan = &portal.statement.statement;
396        let format = &portal.result_column_format;
397
398        if let Some(schema) = &sql_plan.schema {
399            schema_to_pg(schema, format)
400                .map(DescribePortalResponse::new)
401                .map_err(|e| PgWireError::ApiError(Box::new(e)))
402        } else {
403            if let Some(mut resp) =
404                fixtures::process(&sql_plan.query, self.session.new_query_context())
405            {
406                if let Response::Query(query_response) = resp.remove(0) {
407                    return Ok(DescribePortalResponse::new(
408                        (*query_response.row_schema()).clone(),
409                    ));
410                }
411            }
412
413            Ok(DescribePortalResponse::new(vec![]))
414        }
415    }
416}
417
418impl ErrorHandler for PostgresServerHandlerInner {
419    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
420    where
421        C: ClientInfo,
422    {
423        debug!("Postgres interface error {}", error)
424    }
425}