servers/postgres/
handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_query::{Output, OutputData};
20use common_recordbatch::RecordBatch;
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_telemetry::{debug, tracing};
23use datafusion_common::ParamValues;
24use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31    DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::Session;
39use session::context::QueryContextRef;
40use snafu::ResultExt;
41use sql::dialect::PostgreSqlDialect;
42use sql::parser::{ParseOptions, ParserContext};
43
44use crate::SqlPlan;
45use crate::error::{DataFusionSnafu, Result};
46use crate::postgres::types::*;
47use crate::postgres::utils::convert_err;
48use crate::postgres::{PostgresServerHandlerInner, fixtures};
49use crate::query_handler::sql::ServerSqlQueryHandlerRef;
50
51#[async_trait]
52impl SimpleQueryHandler for PostgresServerHandlerInner {
53    #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
54    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
55    where
56        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
57        C::Error: Debug,
58        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
59    {
60        let query_ctx = self.session.new_query_context();
61        let db = query_ctx.get_db_string();
62        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
63            .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
64            .start_timer();
65
66        if query.is_empty() {
67            // early return if query is empty
68            return Ok(vec![Response::EmptyQuery]);
69        }
70
71        let query = if let Ok(statements) = self.query_parser.compatibility_parser.parse(query) {
72            statements
73                .iter()
74                .map(|s| s.to_string())
75                .collect::<Vec<_>>()
76                .join(";")
77        } else {
78            query.to_string()
79        };
80
81        if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
82            send_warning_opt(client, query_ctx).await?;
83            Ok(resps)
84        } else {
85            let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
86
87            let mut results = Vec::with_capacity(outputs.len());
88
89            for output in outputs {
90                let resp =
91                    output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
92                results.push(resp);
93            }
94
95            send_warning_opt(client, query_ctx).await?;
96            Ok(results)
97        }
98    }
99}
100
101async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
102where
103    C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
104    C::Error: Debug,
105    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
106{
107    if let Some(warning) = query_context.warning() {
108        client
109            .feed(PgWireBackendMessage::NoticeResponse(
110                ErrorInfo::new(
111                    PgErrorSeverity::Warning.to_string(),
112                    PgErrorCode::Ec01000.code(),
113                    warning.clone(),
114                )
115                .into(),
116            ))
117            .await?;
118    }
119
120    Ok(())
121}
122
123pub(crate) fn output_to_query_response(
124    query_ctx: QueryContextRef,
125    output: Result<Output>,
126    field_format: &Format,
127) -> PgWireResult<Response> {
128    match output {
129        Ok(o) => match o.data {
130            OutputData::AffectedRows(rows) => {
131                Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
132            }
133            OutputData::Stream(record_stream) => {
134                let schema = record_stream.schema();
135                recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
136            }
137            OutputData::RecordBatches(recordbatches) => {
138                let schema = recordbatches.schema();
139                recordbatches_to_query_response(
140                    query_ctx,
141                    recordbatches.as_stream(),
142                    schema,
143                    field_format,
144                )
145            }
146        },
147        Err(e) => Err(convert_err(e)),
148    }
149}
150
151fn recordbatches_to_query_response<S>(
152    query_ctx: QueryContextRef,
153    recordbatches_stream: S,
154    schema: SchemaRef,
155    field_format: &Format,
156) -> PgWireResult<Response>
157where
158    S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
159{
160    let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?);
161    let pg_schema_ref = pg_schema.clone();
162    let data_row_stream = recordbatches_stream
163        .map(|record_batch_result| match record_batch_result {
164            Ok(rb) => stream::iter(
165                // collect rows from a single recordbatch into vector to avoid
166                // borrowing it
167                rb.rows().map(Ok).collect::<Vec<_>>(),
168            )
169            .boxed(),
170            Err(e) => stream::once(future::err(convert_err(e))).boxed(),
171        })
172        .flatten() // flatten into stream<result<row>>
173        .map(move |row| {
174            row.and_then(|row| {
175                let mut encoder = DataRowEncoder::new(pg_schema_ref.clone());
176                for (value, column) in row.iter().zip(schema.column_schemas()) {
177                    encode_value(&query_ctx, value, &mut encoder, &column.data_type)?;
178                }
179                encoder.finish()
180            })
181        });
182
183    Ok(Response::Query(QueryResponse::new(
184        pg_schema,
185        data_row_stream,
186    )))
187}
188
189pub struct DefaultQueryParser {
190    query_handler: ServerSqlQueryHandlerRef,
191    session: Arc<Session>,
192    compatibility_parser: PostgresCompatibilityParser,
193}
194
195impl DefaultQueryParser {
196    pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
197        DefaultQueryParser {
198            query_handler,
199            session,
200            compatibility_parser: PostgresCompatibilityParser::new(),
201        }
202    }
203}
204
205#[async_trait]
206impl QueryParser for DefaultQueryParser {
207    type Statement = SqlPlan;
208
209    async fn parse_sql<C>(
210        &self,
211        _client: &C,
212        sql: &str,
213        _types: &[Type],
214    ) -> PgWireResult<Self::Statement> {
215        crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
216        let query_ctx = self.session.new_query_context();
217
218        // do not parse if query is empty or matches rules
219        if sql.is_empty() || fixtures::matches(sql) {
220            return Ok(SqlPlan {
221                query: sql.to_owned(),
222                statement: None,
223                plan: None,
224                schema: None,
225            });
226        }
227
228        let sql = if let Ok(mut statements) = self.compatibility_parser.parse(sql) {
229            statements.remove(0).to_string()
230        } else {
231            // bypass the error: it can run into error because of different
232            // versions of sqlparser
233            sql.to_string()
234        };
235
236        let mut stmts = ParserContext::create_with_dialect(
237            &sql,
238            &PostgreSqlDialect {},
239            ParseOptions::default(),
240        )
241        .map_err(convert_err)?;
242        if stmts.len() != 1 {
243            Err(PgWireError::UserError(Box::new(ErrorInfo::from(
244                PgErrorCode::Ec42P14,
245            ))))
246        } else {
247            let stmt = stmts.remove(0);
248
249            let describe_result = self
250                .query_handler
251                .do_describe(stmt.clone(), query_ctx)
252                .await
253                .map_err(convert_err)?;
254
255            let (plan, schema) = if let Some(DescribeResult {
256                logical_plan,
257                schema,
258            }) = describe_result
259            {
260                (Some(logical_plan), Some(schema))
261            } else {
262                (None, None)
263            };
264
265            Ok(SqlPlan {
266                query: sql.clone(),
267                statement: Some(stmt),
268                plan,
269                schema,
270            })
271        }
272    }
273}
274
275#[async_trait]
276impl ExtendedQueryHandler for PostgresServerHandlerInner {
277    type Statement = SqlPlan;
278    type QueryParser = DefaultQueryParser;
279
280    fn query_parser(&self) -> Arc<Self::QueryParser> {
281        self.query_parser.clone()
282    }
283
284    async fn do_query<C>(
285        &self,
286        client: &mut C,
287        portal: &Portal<Self::Statement>,
288        _max_rows: usize,
289    ) -> PgWireResult<Response>
290    where
291        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
292        C::Error: Debug,
293        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
294    {
295        let query_ctx = self.session.new_query_context();
296        let db = query_ctx.get_db_string();
297        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
298            .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
299            .start_timer();
300
301        let sql_plan = &portal.statement.statement;
302
303        if sql_plan.query.is_empty() {
304            // early return if query is empty
305            return Ok(Response::EmptyQuery);
306        }
307
308        if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
309            send_warning_opt(client, query_ctx).await?;
310            // if the statement matches our predefined rules, return it early
311            return Ok(resps.remove(0));
312        }
313
314        let output = if let Some(plan) = &sql_plan.plan {
315            let plan = plan
316                .clone()
317                .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
318                    plan, portal,
319                )?))
320                .context(DataFusionSnafu)
321                .map_err(convert_err)?;
322            self.query_handler
323                .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
324                .await
325        } else {
326            // manually replace variables in prepared statement when no
327            // logical_plan is generated. This happens when logical plan is not
328            // supported for certain statements.
329            let mut sql = sql_plan.query.clone();
330            for i in 0..portal.parameter_len() {
331                sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
332            }
333
334            self.query_handler
335                .do_query(&sql, query_ctx.clone())
336                .await
337                .remove(0)
338        };
339
340        send_warning_opt(client, query_ctx.clone()).await?;
341        output_to_query_response(query_ctx, output, &portal.result_column_format)
342    }
343
344    async fn do_describe_statement<C>(
345        &self,
346        _client: &mut C,
347        stmt: &StoredStatement<Self::Statement>,
348    ) -> PgWireResult<DescribeStatementResponse>
349    where
350        C: ClientInfo + Unpin + Send + Sync,
351    {
352        let sql_plan = &stmt.statement;
353        let (param_types, sql_plan, format) = if let Some(plan) = &sql_plan.plan {
354            let param_types = plan
355                .get_parameter_types()
356                .context(DataFusionSnafu)
357                .map_err(convert_err)?
358                .into_iter()
359                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
360                .collect();
361
362            let types = param_types_to_pg_types(&param_types).map_err(convert_err)?;
363
364            (types, sql_plan, &Format::UnifiedBinary)
365        } else {
366            let param_types = stmt.parameter_types.clone();
367            (param_types, sql_plan, &Format::UnifiedBinary)
368        };
369
370        if let Some(schema) = &sql_plan.schema {
371            schema_to_pg(schema, format)
372                .map(|fields| DescribeStatementResponse::new(param_types, fields))
373                .map_err(convert_err)
374        } else {
375            if let Some(mut resp) =
376                fixtures::process(&sql_plan.query, self.session.new_query_context())
377                && let Response::Query(query_response) = resp.remove(0)
378            {
379                return Ok(DescribeStatementResponse::new(
380                    param_types,
381                    (*query_response.row_schema()).clone(),
382                ));
383            }
384
385            Ok(DescribeStatementResponse::new(param_types, vec![]))
386        }
387    }
388
389    async fn do_describe_portal<C>(
390        &self,
391        _client: &mut C,
392        portal: &Portal<Self::Statement>,
393    ) -> PgWireResult<DescribePortalResponse>
394    where
395        C: ClientInfo + Unpin + Send + Sync,
396    {
397        let sql_plan = &portal.statement.statement;
398        let format = &portal.result_column_format;
399
400        if let Some(schema) = &sql_plan.schema {
401            schema_to_pg(schema, format)
402                .map(DescribePortalResponse::new)
403                .map_err(convert_err)
404        } else {
405            if let Some(mut resp) =
406                fixtures::process(&sql_plan.query, self.session.new_query_context())
407                && let Response::Query(query_response) = resp.remove(0)
408            {
409                return Ok(DescribePortalResponse::new(
410                    (*query_response.row_schema()).clone(),
411                ));
412            }
413
414            Ok(DescribePortalResponse::new(vec![]))
415        }
416    }
417}
418
419impl ErrorHandler for PostgresServerHandlerInner {
420    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
421    where
422        C: ClientInfo,
423    {
424        debug!("Postgres interface error {}", error)
425    }
426}