servers/postgres/
handler.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_query::{Output, OutputData};
20use common_recordbatch::error::Result as RecordBatchResult;
21use common_recordbatch::RecordBatch;
22use common_telemetry::{debug, tracing};
23use datafusion_common::ParamValues;
24use datatypes::prelude::ConcreteDataType;
25use datatypes::schema::SchemaRef;
26use futures::{future, stream, Sink, SinkExt, Stream, StreamExt};
27use pgwire::api::portal::{Format, Portal};
28use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
29use pgwire::api::results::{
30    DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, Tag,
31};
32use pgwire::api::stmt::{QueryParser, StoredStatement};
33use pgwire::api::{ClientInfo, ErrorHandler, Type};
34use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
35use pgwire::messages::PgWireBackendMessage;
36use query::query_engine::DescribeResult;
37use session::context::QueryContextRef;
38use session::Session;
39use snafu::ResultExt;
40use sql::dialect::PostgreSqlDialect;
41use sql::parser::{ParseOptions, ParserContext};
42
43use crate::error::{DataFusionSnafu, Result};
44use crate::postgres::types::*;
45use crate::postgres::utils::convert_err;
46use crate::postgres::{fixtures, PostgresServerHandlerInner};
47use crate::query_handler::sql::ServerSqlQueryHandlerRef;
48use crate::SqlPlan;
49
50#[async_trait]
51impl SimpleQueryHandler for PostgresServerHandlerInner {
52    #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
53    async fn do_query<'a, C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response<'a>>>
54    where
55        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
56        C::Error: Debug,
57        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
58    {
59        let query_ctx = self.session.new_query_context();
60        let db = query_ctx.get_db_string();
61        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
62            .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
63            .start_timer();
64
65        if query.is_empty() {
66            // early return if query is empty
67            return Ok(vec![Response::EmptyQuery]);
68        }
69
70        let query = fixtures::rewrite_sql(query);
71        let query = query.as_ref();
72
73        if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
74            send_warning_opt(client, query_ctx).await?;
75            Ok(resps)
76        } else {
77            let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;
78
79            let mut results = Vec::with_capacity(outputs.len());
80
81            for output in outputs {
82                let resp =
83                    output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
84                results.push(resp);
85            }
86
87            send_warning_opt(client, query_ctx).await?;
88            Ok(results)
89        }
90    }
91}
92
93async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
94where
95    C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
96    C::Error: Debug,
97    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
98{
99    if let Some(warning) = query_context.warning() {
100        client
101            .feed(PgWireBackendMessage::NoticeResponse(
102                ErrorInfo::new(
103                    PgErrorSeverity::Warning.to_string(),
104                    PgErrorCode::Ec01000.code(),
105                    warning.to_string(),
106                )
107                .into(),
108            ))
109            .await?;
110    }
111
112    Ok(())
113}
114
115pub(crate) fn output_to_query_response<'a>(
116    query_ctx: QueryContextRef,
117    output: Result<Output>,
118    field_format: &Format,
119) -> PgWireResult<Response<'a>> {
120    match output {
121        Ok(o) => match o.data {
122            OutputData::AffectedRows(rows) => {
123                Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
124            }
125            OutputData::Stream(record_stream) => {
126                let schema = record_stream.schema();
127                recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
128            }
129            OutputData::RecordBatches(recordbatches) => {
130                let schema = recordbatches.schema();
131                recordbatches_to_query_response(
132                    query_ctx,
133                    recordbatches.as_stream(),
134                    schema,
135                    field_format,
136                )
137            }
138        },
139        Err(e) => Err(convert_err(e)),
140    }
141}
142
143fn recordbatches_to_query_response<'a, S>(
144    query_ctx: QueryContextRef,
145    recordbatches_stream: S,
146    schema: SchemaRef,
147    field_format: &Format,
148) -> PgWireResult<Response<'a>>
149where
150    S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
151{
152    let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?);
153    let pg_schema_ref = pg_schema.clone();
154    let data_row_stream = recordbatches_stream
155        .map(|record_batch_result| match record_batch_result {
156            Ok(rb) => stream::iter(
157                // collect rows from a single recordbatch into vector to avoid
158                // borrowing it
159                rb.rows().map(Ok).collect::<Vec<_>>(),
160            )
161            .boxed(),
162            Err(e) => stream::once(future::err(convert_err(e))).boxed(),
163        })
164        .flatten() // flatten into stream<result<row>>
165        .map(move |row| {
166            row.and_then(|row| {
167                let mut encoder = DataRowEncoder::new(pg_schema_ref.clone());
168                for (value, column) in row.iter().zip(schema.column_schemas()) {
169                    encode_value(&query_ctx, value, &mut encoder, &column.data_type)?;
170                }
171                encoder.finish()
172            })
173        });
174
175    Ok(Response::Query(QueryResponse::new(
176        pg_schema,
177        data_row_stream,
178    )))
179}
180
181pub struct DefaultQueryParser {
182    query_handler: ServerSqlQueryHandlerRef,
183    session: Arc<Session>,
184}
185
186impl DefaultQueryParser {
187    pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
188        DefaultQueryParser {
189            query_handler,
190            session,
191        }
192    }
193}
194
195#[async_trait]
196impl QueryParser for DefaultQueryParser {
197    type Statement = SqlPlan;
198
199    async fn parse_sql<C>(
200        &self,
201        _client: &C,
202        sql: &str,
203        _types: &[Type],
204    ) -> PgWireResult<Self::Statement> {
205        crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
206        let query_ctx = self.session.new_query_context();
207
208        // do not parse if query is empty or matches rules
209        if sql.is_empty() || fixtures::matches(sql) {
210            return Ok(SqlPlan {
211                query: sql.to_owned(),
212                statement: None,
213                plan: None,
214                schema: None,
215            });
216        }
217
218        let sql = fixtures::rewrite_sql(sql);
219        let sql = sql.as_ref();
220
221        let mut stmts =
222            ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default())
223                .map_err(convert_err)?;
224        if stmts.len() != 1 {
225            Err(PgWireError::UserError(Box::new(ErrorInfo::from(
226                PgErrorCode::Ec42P14,
227            ))))
228        } else {
229            let stmt = stmts.remove(0);
230
231            let describe_result = self
232                .query_handler
233                .do_describe(stmt.clone(), query_ctx)
234                .await
235                .map_err(convert_err)?;
236
237            let (plan, schema) = if let Some(DescribeResult {
238                logical_plan,
239                schema,
240            }) = describe_result
241            {
242                (Some(logical_plan), Some(schema))
243            } else {
244                (None, None)
245            };
246
247            Ok(SqlPlan {
248                query: sql.to_owned(),
249                statement: Some(stmt),
250                plan,
251                schema,
252            })
253        }
254    }
255}
256
257#[async_trait]
258impl ExtendedQueryHandler for PostgresServerHandlerInner {
259    type Statement = SqlPlan;
260    type QueryParser = DefaultQueryParser;
261
262    fn query_parser(&self) -> Arc<Self::QueryParser> {
263        self.query_parser.clone()
264    }
265
266    async fn do_query<'a, C>(
267        &self,
268        client: &mut C,
269        portal: &Portal<Self::Statement>,
270        _max_rows: usize,
271    ) -> PgWireResult<Response<'a>>
272    where
273        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
274        C::Error: Debug,
275        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
276    {
277        let query_ctx = self.session.new_query_context();
278        let db = query_ctx.get_db_string();
279        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
280            .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
281            .start_timer();
282
283        let sql_plan = &portal.statement.statement;
284
285        if sql_plan.query.is_empty() {
286            // early return if query is empty
287            return Ok(Response::EmptyQuery);
288        }
289
290        if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
291            send_warning_opt(client, query_ctx).await?;
292            // if the statement matches our predefined rules, return it early
293            return Ok(resps.remove(0));
294        }
295
296        let output = if let Some(plan) = &sql_plan.plan {
297            let plan = plan
298                .clone()
299                .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
300                    plan, portal,
301                )?))
302                .context(DataFusionSnafu)
303                .map_err(convert_err)?;
304            self.query_handler
305                .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
306                .await
307        } else {
308            // manually replace variables in prepared statement when no
309            // logical_plan is generated. This happens when logical plan is not
310            // supported for certain statements.
311            let mut sql = sql_plan.query.clone();
312            for i in 0..portal.parameter_len() {
313                sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
314            }
315
316            self.query_handler
317                .do_query(&sql, query_ctx.clone())
318                .await
319                .remove(0)
320        };
321
322        send_warning_opt(client, query_ctx.clone()).await?;
323        output_to_query_response(query_ctx, output, &portal.result_column_format)
324    }
325
326    async fn do_describe_statement<C>(
327        &self,
328        _client: &mut C,
329        stmt: &StoredStatement<Self::Statement>,
330    ) -> PgWireResult<DescribeStatementResponse>
331    where
332        C: ClientInfo + Unpin + Send + Sync,
333    {
334        let sql_plan = &stmt.statement;
335        let (param_types, sql_plan, format) = if let Some(plan) = &sql_plan.plan {
336            let param_types = plan
337                .get_parameter_types()
338                .context(DataFusionSnafu)
339                .map_err(convert_err)?
340                .into_iter()
341                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
342                .collect();
343
344            let types = param_types_to_pg_types(&param_types).map_err(convert_err)?;
345
346            (types, sql_plan, &Format::UnifiedBinary)
347        } else {
348            let param_types = stmt.parameter_types.clone();
349            (param_types, sql_plan, &Format::UnifiedBinary)
350        };
351
352        if let Some(schema) = &sql_plan.schema {
353            schema_to_pg(schema, format)
354                .map(|fields| DescribeStatementResponse::new(param_types, fields))
355                .map_err(convert_err)
356        } else {
357            if let Some(mut resp) =
358                fixtures::process(&sql_plan.query, self.session.new_query_context())
359            {
360                if let Response::Query(query_response) = resp.remove(0) {
361                    return Ok(DescribeStatementResponse::new(
362                        param_types,
363                        (*query_response.row_schema()).clone(),
364                    ));
365                }
366            }
367
368            Ok(DescribeStatementResponse::new(param_types, vec![]))
369        }
370    }
371
372    async fn do_describe_portal<C>(
373        &self,
374        _client: &mut C,
375        portal: &Portal<Self::Statement>,
376    ) -> PgWireResult<DescribePortalResponse>
377    where
378        C: ClientInfo + Unpin + Send + Sync,
379    {
380        let sql_plan = &portal.statement.statement;
381        let format = &portal.result_column_format;
382
383        if let Some(schema) = &sql_plan.schema {
384            schema_to_pg(schema, format)
385                .map(DescribePortalResponse::new)
386                .map_err(convert_err)
387        } else {
388            if let Some(mut resp) =
389                fixtures::process(&sql_plan.query, self.session.new_query_context())
390            {
391                if let Response::Query(query_response) = resp.remove(0) {
392                    return Ok(DescribePortalResponse::new(
393                        (*query_response.row_schema()).clone(),
394                    ));
395                }
396            }
397
398            Ok(DescribePortalResponse::new(vec![]))
399        }
400    }
401}
402
403impl ErrorHandler for PostgresServerHandlerInner {
404    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
405    where
406        C: ClientInfo,
407    {
408        debug!("Postgres interface error {}", error)
409    }
410}