servers/postgres/
handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_query::{Output, OutputData};
20use common_recordbatch::RecordBatch;
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_telemetry::{debug, tracing};
23use datafusion_common::ParamValues;
24use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31    DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::Session;
39use session::context::QueryContextRef;
40use snafu::ResultExt;
41use sql::dialect::PostgreSqlDialect;
42use sql::parser::{ParseOptions, ParserContext};
43use sql::statements::statement::Statement;
44
45use crate::SqlPlan;
46use crate::error::{DataFusionSnafu, Result};
47use crate::postgres::types::*;
48use crate::postgres::utils::convert_err;
49use crate::postgres::{PostgresServerHandlerInner, fixtures};
50use crate::query_handler::sql::ServerSqlQueryHandlerRef;
51
52#[async_trait]
53impl SimpleQueryHandler for PostgresServerHandlerInner {
54    #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
55    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
56    where
57        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
58        C::Error: Debug,
59        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
60    {
61        let query_ctx = self.session.new_query_context();
62        let db = query_ctx.get_db_string();
63        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
64            .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
65            .start_timer();
66
67        if query.is_empty() {
68            // early return if query is empty
69            return Ok(vec![Response::EmptyQuery]);
70        }
71
72        let query = if let Ok(statements) = self.query_parser.compatibility_parser.parse(query) {
73            statements
74                .iter()
75                .map(|s| s.to_string())
76                .collect::<Vec<_>>()
77                .join(";")
78        } else {
79            query.to_string()
80        };
81
82        if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
83            send_warning_opt(client, query_ctx).await?;
84            Ok(resps)
85        } else {
86            let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
87
88            let mut results = Vec::with_capacity(outputs.len());
89
90            for output in outputs {
91                let resp =
92                    output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
93                results.push(resp);
94            }
95
96            send_warning_opt(client, query_ctx).await?;
97            Ok(results)
98        }
99    }
100}
101
102async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
103where
104    C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
105    C::Error: Debug,
106    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
107{
108    if let Some(warning) = query_context.warning() {
109        client
110            .feed(PgWireBackendMessage::NoticeResponse(
111                ErrorInfo::new(
112                    PgErrorSeverity::Warning.to_string(),
113                    PgErrorCode::Ec01000.code(),
114                    warning.clone(),
115                )
116                .into(),
117            ))
118            .await?;
119    }
120
121    Ok(())
122}
123
124pub(crate) fn output_to_query_response(
125    query_ctx: QueryContextRef,
126    output: Result<Output>,
127    field_format: &Format,
128) -> PgWireResult<Response> {
129    match output {
130        Ok(o) => match o.data {
131            OutputData::AffectedRows(rows) => {
132                Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
133            }
134            OutputData::Stream(record_stream) => {
135                let schema = record_stream.schema();
136                recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
137            }
138            OutputData::RecordBatches(recordbatches) => {
139                let schema = recordbatches.schema();
140                recordbatches_to_query_response(
141                    query_ctx,
142                    recordbatches.as_stream(),
143                    schema,
144                    field_format,
145                )
146            }
147        },
148        Err(e) => Err(convert_err(e)),
149    }
150}
151
152fn recordbatches_to_query_response<S>(
153    query_ctx: QueryContextRef,
154    recordbatches_stream: S,
155    schema: SchemaRef,
156    field_format: &Format,
157) -> PgWireResult<Response>
158where
159    S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
160{
161    let format_options = format_options_from_query_ctx(&query_ctx);
162    let pg_schema = Arc::new(
163        schema_to_pg(schema.as_ref(), field_format, Some(format_options)).map_err(convert_err)?,
164    );
165    let pg_schema_ref = pg_schema.clone();
166    let data_row_stream = recordbatches_stream
167        .map(move |result| match result {
168            Ok(record_batch) => stream::iter(RecordBatchRowIterator::new(
169                query_ctx.clone(),
170                pg_schema_ref.clone(),
171                record_batch,
172            ))
173            .boxed(),
174            Err(e) => stream::once(future::err(convert_err(e))).boxed(),
175        })
176        .flatten();
177
178    Ok(Response::Query(QueryResponse::new(
179        pg_schema,
180        data_row_stream,
181    )))
182}
183
/// Query parser used by the extended query protocol to turn SQL text into a
/// [`SqlPlan`].
pub struct DefaultQueryParser {
    /// Handler used to describe a parsed statement (logical plan and schema).
    query_handler: ServerSqlQueryHandlerRef,
    /// Session from which a fresh query context is created for each parse.
    session: Arc<Session>,
    /// Parser applied to the incoming SQL text before it is handed to the
    /// regular SQL parser; its output replaces the original text when it
    /// succeeds.
    compatibility_parser: PostgresCompatibilityParser,
}
189
190impl DefaultQueryParser {
191    pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
192        DefaultQueryParser {
193            query_handler,
194            session,
195            compatibility_parser: PostgresCompatibilityParser::new(),
196        }
197    }
198}
199
200#[async_trait]
201impl QueryParser for DefaultQueryParser {
202    type Statement = SqlPlan;
203
204    async fn parse_sql<C>(
205        &self,
206        _client: &C,
207        sql: &str,
208        _types: &[Option<Type>],
209    ) -> PgWireResult<Self::Statement> {
210        crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
211        let query_ctx = self.session.new_query_context();
212
213        // do not parse if query is empty or matches rules
214        if sql.is_empty() || fixtures::matches(sql) {
215            return Ok(SqlPlan {
216                query: sql.to_owned(),
217                statement: None,
218                plan: None,
219                schema: None,
220            });
221        }
222
223        let sql = if let Ok(mut statements) = self.compatibility_parser.parse(sql) {
224            statements.remove(0).to_string()
225        } else {
226            // bypass the error: it can run into error because of different
227            // versions of sqlparser
228            sql.to_string()
229        };
230
231        let mut stmts = ParserContext::create_with_dialect(
232            &sql,
233            &PostgreSqlDialect {},
234            ParseOptions::default(),
235        )
236        .map_err(convert_err)?;
237        if stmts.len() != 1 {
238            Err(PgWireError::UserError(Box::new(ErrorInfo::from(
239                PgErrorCode::Ec42P14,
240            ))))
241        } else {
242            let stmt = stmts.remove(0);
243
244            let describe_result = self
245                .query_handler
246                .do_describe(stmt.clone(), query_ctx)
247                .await
248                .map_err(convert_err)?;
249
250            let (plan, schema) = if let Some(DescribeResult {
251                logical_plan,
252                schema,
253            }) = describe_result
254            {
255                (Some(logical_plan), Some(schema))
256            } else {
257                (None, None)
258            };
259
260            Ok(SqlPlan {
261                query: sql.clone(),
262                statement: Some(stmt),
263                plan,
264                schema,
265            })
266        }
267    }
268
269    fn get_parameter_types(&self, _stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
270        // we have our own implementation of describes in ExtendedQueryHandler
271        // so we don't use these methods
272        Err(PgWireError::ApiError(
273            "get_parameter_types is not expected to be called".into(),
274        ))
275    }
276
277    fn get_result_schema(
278        &self,
279        _stmt: &Self::Statement,
280        _column_format: Option<&Format>,
281    ) -> PgWireResult<Vec<FieldInfo>> {
282        // we have our own implementation of describes in ExtendedQueryHandler
283        // so we don't use these methods
284        Err(PgWireError::ApiError(
285            "get_result_schema is not expected to be called".into(),
286        ))
287    }
288}
289
290#[async_trait]
291impl ExtendedQueryHandler for PostgresServerHandlerInner {
292    type Statement = SqlPlan;
293    type QueryParser = DefaultQueryParser;
294
295    fn query_parser(&self) -> Arc<Self::QueryParser> {
296        self.query_parser.clone()
297    }
298
299    async fn do_query<C>(
300        &self,
301        client: &mut C,
302        portal: &Portal<Self::Statement>,
303        _max_rows: usize,
304    ) -> PgWireResult<Response>
305    where
306        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
307        C::Error: Debug,
308        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
309    {
310        let query_ctx = self.session.new_query_context();
311        let db = query_ctx.get_db_string();
312        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
313            .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
314            .start_timer();
315
316        let sql_plan = &portal.statement.statement;
317
318        if sql_plan.query.is_empty() {
319            // early return if query is empty
320            return Ok(Response::EmptyQuery);
321        }
322
323        if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
324            send_warning_opt(client, query_ctx).await?;
325            // if the statement matches our predefined rules, return it early
326            return Ok(resps.remove(0));
327        }
328
329        let output = if let Some(plan) = &sql_plan.plan {
330            let values = parameters_to_scalar_values(plan, portal)?;
331            let plan = plan
332                .clone()
333                .replace_params_with_values(&ParamValues::List(
334                    values.into_iter().map(Into::into).collect(),
335                ))
336                .context(DataFusionSnafu)
337                .map_err(convert_err)?;
338            self.query_handler
339                .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
340                .await
341        } else {
342            // manually replace variables in prepared statement when no
343            // logical_plan is generated. This happens when logical plan is not
344            // supported for certain statements.
345            let mut sql = sql_plan.query.clone();
346            for i in 0..portal.parameter_len() {
347                sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
348            }
349
350            self.query_handler
351                .do_query(&sql, query_ctx.clone())
352                .await
353                .remove(0)
354        };
355
356        send_warning_opt(client, query_ctx.clone()).await?;
357        output_to_query_response(query_ctx, output, &portal.result_column_format)
358    }
359
360    async fn do_describe_statement<C>(
361        &self,
362        _client: &mut C,
363        stmt: &StoredStatement<Self::Statement>,
364    ) -> PgWireResult<DescribeStatementResponse>
365    where
366        C: ClientInfo + Unpin + Send + Sync,
367    {
368        let sql_plan = &stmt.statement;
369        // client provided parameter types, can be empty if client doesn't try to parse statement
370        let provided_param_types = &stmt.parameter_types;
371        let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
372            let param_types = plan
373                .get_parameter_types()
374                .context(DataFusionSnafu)
375                .map_err(convert_err)?
376                .into_iter()
377                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
378                .collect();
379
380            let types = param_types_to_pg_types(&param_types).map_err(convert_err)?;
381
382            Some(types)
383        } else {
384            None
385        };
386
387        let param_count = if provided_param_types.is_empty() {
388            server_inferenced_types
389                .as_ref()
390                .map(|types| types.len())
391                .unwrap_or(0)
392        } else {
393            provided_param_types.len()
394        };
395
396        let param_types = (0..param_count)
397            .map(|i| {
398                let client_type = provided_param_types.get(i);
399                // use server type when client provided type is None (oid: 0 or other invalid values)
400                match client_type {
401                    Some(Some(client_type)) => client_type.clone(),
402                    _ => server_inferenced_types
403                        .as_ref()
404                        .and_then(|types| types.get(i).cloned())
405                        .unwrap_or(Type::UNKNOWN),
406                }
407            })
408            .collect::<Vec<_>>();
409
410        if let Some(schema) = &sql_plan.schema {
411            schema_to_pg(schema, &Format::UnifiedBinary, None)
412                .map(|fields| DescribeStatementResponse::new(param_types, fields))
413                .map_err(convert_err)
414        } else {
415            if let Some(mut resp) =
416                fixtures::process(&sql_plan.query, self.session.new_query_context())
417                && let Response::Query(query_response) = resp.remove(0)
418            {
419                return Ok(DescribeStatementResponse::new(
420                    param_types,
421                    (*query_response.row_schema()).clone(),
422                ));
423            }
424
425            Ok(DescribeStatementResponse::new(param_types, vec![]))
426        }
427    }
428
429    async fn do_describe_portal<C>(
430        &self,
431        _client: &mut C,
432        portal: &Portal<Self::Statement>,
433    ) -> PgWireResult<DescribePortalResponse>
434    where
435        C: ClientInfo + Unpin + Send + Sync,
436    {
437        let sql_plan = &portal.statement.statement;
438        let format = &portal.result_column_format;
439
440        match sql_plan.statement.as_ref() {
441            Some(Statement::Query(_)) => {
442                // if the query has a schema, it is managed by datafusion, use the schema
443                if let Some(schema) = &sql_plan.schema {
444                    schema_to_pg(schema, format, None)
445                        .map(DescribePortalResponse::new)
446                        .map_err(convert_err)
447                } else {
448                    // fallback to NoData
449                    Ok(DescribePortalResponse::new(vec![]))
450                }
451            }
452            // We can cover only part of show statements
453            // these show create statements will return 2 columns
454            Some(Statement::ShowCreateDatabase(_))
455            | Some(Statement::ShowCreateTable(_))
456            | Some(Statement::ShowCreateFlow(_))
457            | Some(Statement::ShowCreateView(_)) => Ok(DescribePortalResponse::new(vec![
458                FieldInfo::new(
459                    "name".to_string(),
460                    None,
461                    None,
462                    Type::TEXT,
463                    format.format_for(0),
464                ),
465                FieldInfo::new(
466                    "create_statement".to_string(),
467                    None,
468                    None,
469                    Type::TEXT,
470                    format.format_for(1),
471                ),
472            ])),
473            // single column show statements
474            Some(Statement::ShowTables(_))
475            | Some(Statement::ShowFlows(_))
476            | Some(Statement::ShowViews(_)) => {
477                Ok(DescribePortalResponse::new(vec![FieldInfo::new(
478                    "name".to_string(),
479                    None,
480                    None,
481                    Type::TEXT,
482                    format.format_for(0),
483                )]))
484            }
485            // we will not support other show statements for extended query protocol at least for now.
486            // because the return columns is not predictable at this stage
487            _ => {
488                // test if query caught by fixture
489                if let Some(mut resp) =
490                    fixtures::process(&sql_plan.query, self.session.new_query_context())
491                    && let Response::Query(query_response) = resp.remove(0)
492                {
493                    Ok(DescribePortalResponse::new(
494                        (*query_response.row_schema()).clone(),
495                    ))
496                } else {
497                    // fallback to NoData
498                    Ok(DescribePortalResponse::new(vec![]))
499                }
500            }
501        }
502    }
503}
504
505impl ErrorHandler for PostgresServerHandlerInner {
506    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
507    where
508        C: ClientInfo,
509    {
510        debug!("Postgres interface error {}", error)
511    }
512}