servers/postgres/
handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_query::{Output, OutputData};
20use common_recordbatch::RecordBatch;
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_telemetry::{debug, tracing};
23use datafusion_common::ParamValues;
24use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31    DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::Session;
39use session::context::QueryContextRef;
40use snafu::ResultExt;
41use sql::dialect::PostgreSqlDialect;
42use sql::parser::{ParseOptions, ParserContext};
43
44use crate::SqlPlan;
45use crate::error::{DataFusionSnafu, Result};
46use crate::postgres::types::*;
47use crate::postgres::utils::convert_err;
48use crate::postgres::{PostgresServerHandlerInner, fixtures};
49use crate::query_handler::sql::ServerSqlQueryHandlerRef;
50
51#[async_trait]
52impl SimpleQueryHandler for PostgresServerHandlerInner {
53    #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
54    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
55    where
56        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
57        C::Error: Debug,
58        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
59    {
60        let query_ctx = self.session.new_query_context();
61        let db = query_ctx.get_db_string();
62        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
63            .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
64            .start_timer();
65
66        if query.is_empty() {
67            // early return if query is empty
68            return Ok(vec![Response::EmptyQuery]);
69        }
70
71        let query = if let Ok(statements) = self.query_parser.compatibility_parser.parse(query) {
72            statements
73                .iter()
74                .map(|s| s.to_string())
75                .collect::<Vec<_>>()
76                .join(";")
77        } else {
78            query.to_string()
79        };
80
81        if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
82            send_warning_opt(client, query_ctx).await?;
83            Ok(resps)
84        } else {
85            let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
86
87            let mut results = Vec::with_capacity(outputs.len());
88
89            for output in outputs {
90                let resp =
91                    output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
92                results.push(resp);
93            }
94
95            send_warning_opt(client, query_ctx).await?;
96            Ok(results)
97        }
98    }
99}
100
101async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
102where
103    C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
104    C::Error: Debug,
105    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
106{
107    if let Some(warning) = query_context.warning() {
108        client
109            .feed(PgWireBackendMessage::NoticeResponse(
110                ErrorInfo::new(
111                    PgErrorSeverity::Warning.to_string(),
112                    PgErrorCode::Ec01000.code(),
113                    warning.clone(),
114                )
115                .into(),
116            ))
117            .await?;
118    }
119
120    Ok(())
121}
122
123pub(crate) fn output_to_query_response(
124    query_ctx: QueryContextRef,
125    output: Result<Output>,
126    field_format: &Format,
127) -> PgWireResult<Response> {
128    match output {
129        Ok(o) => match o.data {
130            OutputData::AffectedRows(rows) => {
131                Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
132            }
133            OutputData::Stream(record_stream) => {
134                let schema = record_stream.schema();
135                recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
136            }
137            OutputData::RecordBatches(recordbatches) => {
138                let schema = recordbatches.schema();
139                recordbatches_to_query_response(
140                    query_ctx,
141                    recordbatches.as_stream(),
142                    schema,
143                    field_format,
144                )
145            }
146        },
147        Err(e) => Err(convert_err(e)),
148    }
149}
150
151fn recordbatches_to_query_response<S>(
152    query_ctx: QueryContextRef,
153    recordbatches_stream: S,
154    schema: SchemaRef,
155    field_format: &Format,
156) -> PgWireResult<Response>
157where
158    S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
159{
160    let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?);
161    let pg_schema_ref = pg_schema.clone();
162    let data_row_stream = recordbatches_stream
163        .map(move |result| match result {
164            Ok(record_batch) => stream::iter(RecordBatchRowIterator::new(
165                query_ctx.clone(),
166                pg_schema_ref.clone(),
167                record_batch,
168            ))
169            .boxed(),
170            Err(e) => stream::once(future::err(convert_err(e))).boxed(),
171        })
172        .flatten();
173
174    Ok(Response::Query(QueryResponse::new(
175        pg_schema,
176        data_row_stream,
177    )))
178}
179
180pub struct DefaultQueryParser {
181    query_handler: ServerSqlQueryHandlerRef,
182    session: Arc<Session>,
183    compatibility_parser: PostgresCompatibilityParser,
184}
185
186impl DefaultQueryParser {
187    pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
188        DefaultQueryParser {
189            query_handler,
190            session,
191            compatibility_parser: PostgresCompatibilityParser::new(),
192        }
193    }
194}
195
196#[async_trait]
197impl QueryParser for DefaultQueryParser {
198    type Statement = SqlPlan;
199
200    async fn parse_sql<C>(
201        &self,
202        _client: &C,
203        sql: &str,
204        _types: &[Type],
205    ) -> PgWireResult<Self::Statement> {
206        crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
207        let query_ctx = self.session.new_query_context();
208
209        // do not parse if query is empty or matches rules
210        if sql.is_empty() || fixtures::matches(sql) {
211            return Ok(SqlPlan {
212                query: sql.to_owned(),
213                statement: None,
214                plan: None,
215                schema: None,
216            });
217        }
218
219        let sql = if let Ok(mut statements) = self.compatibility_parser.parse(sql) {
220            statements.remove(0).to_string()
221        } else {
222            // bypass the error: it can run into error because of different
223            // versions of sqlparser
224            sql.to_string()
225        };
226
227        let mut stmts = ParserContext::create_with_dialect(
228            &sql,
229            &PostgreSqlDialect {},
230            ParseOptions::default(),
231        )
232        .map_err(convert_err)?;
233        if stmts.len() != 1 {
234            Err(PgWireError::UserError(Box::new(ErrorInfo::from(
235                PgErrorCode::Ec42P14,
236            ))))
237        } else {
238            let stmt = stmts.remove(0);
239
240            let describe_result = self
241                .query_handler
242                .do_describe(stmt.clone(), query_ctx)
243                .await
244                .map_err(convert_err)?;
245
246            let (plan, schema) = if let Some(DescribeResult {
247                logical_plan,
248                schema,
249            }) = describe_result
250            {
251                (Some(logical_plan), Some(schema))
252            } else {
253                (None, None)
254            };
255
256            Ok(SqlPlan {
257                query: sql.clone(),
258                statement: Some(stmt),
259                plan,
260                schema,
261            })
262        }
263    }
264}
265
266#[async_trait]
267impl ExtendedQueryHandler for PostgresServerHandlerInner {
268    type Statement = SqlPlan;
269    type QueryParser = DefaultQueryParser;
270
271    fn query_parser(&self) -> Arc<Self::QueryParser> {
272        self.query_parser.clone()
273    }
274
275    async fn do_query<C>(
276        &self,
277        client: &mut C,
278        portal: &Portal<Self::Statement>,
279        _max_rows: usize,
280    ) -> PgWireResult<Response>
281    where
282        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
283        C::Error: Debug,
284        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
285    {
286        let query_ctx = self.session.new_query_context();
287        let db = query_ctx.get_db_string();
288        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
289            .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
290            .start_timer();
291
292        let sql_plan = &portal.statement.statement;
293
294        if sql_plan.query.is_empty() {
295            // early return if query is empty
296            return Ok(Response::EmptyQuery);
297        }
298
299        if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
300            send_warning_opt(client, query_ctx).await?;
301            // if the statement matches our predefined rules, return it early
302            return Ok(resps.remove(0));
303        }
304
305        let output = if let Some(plan) = &sql_plan.plan {
306            let plan = plan
307                .clone()
308                .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
309                    plan, portal,
310                )?))
311                .context(DataFusionSnafu)
312                .map_err(convert_err)?;
313            self.query_handler
314                .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
315                .await
316        } else {
317            // manually replace variables in prepared statement when no
318            // logical_plan is generated. This happens when logical plan is not
319            // supported for certain statements.
320            let mut sql = sql_plan.query.clone();
321            for i in 0..portal.parameter_len() {
322                sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
323            }
324
325            self.query_handler
326                .do_query(&sql, query_ctx.clone())
327                .await
328                .remove(0)
329        };
330
331        send_warning_opt(client, query_ctx.clone()).await?;
332        output_to_query_response(query_ctx, output, &portal.result_column_format)
333    }
334
335    async fn do_describe_statement<C>(
336        &self,
337        _client: &mut C,
338        stmt: &StoredStatement<Self::Statement>,
339    ) -> PgWireResult<DescribeStatementResponse>
340    where
341        C: ClientInfo + Unpin + Send + Sync,
342    {
343        let sql_plan = &stmt.statement;
344        let (param_types, sql_plan, format) = if let Some(plan) = &sql_plan.plan {
345            let param_types = plan
346                .get_parameter_types()
347                .context(DataFusionSnafu)
348                .map_err(convert_err)?
349                .into_iter()
350                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
351                .collect();
352
353            let types = param_types_to_pg_types(&param_types).map_err(convert_err)?;
354
355            (types, sql_plan, &Format::UnifiedBinary)
356        } else {
357            let param_types = stmt.parameter_types.clone();
358            (param_types, sql_plan, &Format::UnifiedBinary)
359        };
360
361        if let Some(schema) = &sql_plan.schema {
362            schema_to_pg(schema, format)
363                .map(|fields| DescribeStatementResponse::new(param_types, fields))
364                .map_err(convert_err)
365        } else {
366            if let Some(mut resp) =
367                fixtures::process(&sql_plan.query, self.session.new_query_context())
368                && let Response::Query(query_response) = resp.remove(0)
369            {
370                return Ok(DescribeStatementResponse::new(
371                    param_types,
372                    (*query_response.row_schema()).clone(),
373                ));
374            }
375
376            Ok(DescribeStatementResponse::new(param_types, vec![]))
377        }
378    }
379
380    async fn do_describe_portal<C>(
381        &self,
382        _client: &mut C,
383        portal: &Portal<Self::Statement>,
384    ) -> PgWireResult<DescribePortalResponse>
385    where
386        C: ClientInfo + Unpin + Send + Sync,
387    {
388        let sql_plan = &portal.statement.statement;
389        let format = &portal.result_column_format;
390
391        if let Some(schema) = &sql_plan.schema {
392            schema_to_pg(schema, format)
393                .map(DescribePortalResponse::new)
394                .map_err(convert_err)
395        } else {
396            if let Some(mut resp) =
397                fixtures::process(&sql_plan.query, self.session.new_query_context())
398                && let Response::Query(query_response) = resp.remove(0)
399            {
400                return Ok(DescribePortalResponse::new(
401                    (*query_response.row_schema()).clone(),
402                ));
403            }
404
405            Ok(DescribePortalResponse::new(vec![]))
406        }
407    }
408}
409
410impl ErrorHandler for PostgresServerHandlerInner {
411    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
412    where
413        C: ClientInfo,
414    {
415        debug!("Postgres interface error {}", error)
416    }
417}