servers/postgres/
handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use common_query::{Output, OutputData};
20use common_recordbatch::RecordBatch;
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_telemetry::{debug, tracing};
23use datafusion_common::ParamValues;
24use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
25use datatypes::prelude::ConcreteDataType;
26use datatypes::schema::SchemaRef;
27use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
28use pgwire::api::portal::{Format, Portal};
29use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
30use pgwire::api::results::{
31    DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
32};
33use pgwire::api::stmt::{QueryParser, StoredStatement};
34use pgwire::api::{ClientInfo, ErrorHandler, Type};
35use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
36use pgwire::messages::PgWireBackendMessage;
37use query::query_engine::DescribeResult;
38use session::Session;
39use session::context::QueryContextRef;
40use snafu::ResultExt;
41use sql::dialect::PostgreSqlDialect;
42use sql::parser::{ParseOptions, ParserContext};
43use sql::statements::statement::Statement;
44
45use crate::SqlPlan;
46use crate::error::{DataFusionSnafu, Result};
47use crate::postgres::types::*;
48use crate::postgres::utils::convert_err;
49use crate::postgres::{PostgresServerHandlerInner, fixtures};
50use crate::query_handler::sql::ServerSqlQueryHandlerRef;
51
52#[async_trait]
53impl SimpleQueryHandler for PostgresServerHandlerInner {
54    #[tracing::instrument(skip_all, fields(protocol = "postgres"))]
55    async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
56    where
57        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
58        C::Error: Debug,
59        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
60    {
61        let query_ctx = self.session.new_query_context();
62        let db = query_ctx.get_db_string();
63        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
64            .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
65            .start_timer();
66
67        if query.is_empty() {
68            // early return if query is empty
69            return Ok(vec![Response::EmptyQuery]);
70        }
71
72        let query = if let Ok(statements) = self.query_parser.compatibility_parser.parse(query) {
73            statements
74                .iter()
75                .map(|s| s.to_string())
76                .collect::<Vec<_>>()
77                .join(";")
78        } else {
79            query.to_string()
80        };
81
82        if let Some(resps) = fixtures::process(&query, query_ctx.clone()) {
83            send_warning_opt(client, query_ctx).await?;
84            Ok(resps)
85        } else {
86            let outputs = self.query_handler.do_query(&query, query_ctx.clone()).await;
87
88            let mut results = Vec::with_capacity(outputs.len());
89
90            for output in outputs {
91                let resp =
92                    output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
93                results.push(resp);
94            }
95
96            send_warning_opt(client, query_ctx).await?;
97            Ok(results)
98        }
99    }
100}
101
102async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
103where
104    C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
105    C::Error: Debug,
106    PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
107{
108    if let Some(warning) = query_context.warning() {
109        client
110            .feed(PgWireBackendMessage::NoticeResponse(
111                ErrorInfo::new(
112                    PgErrorSeverity::Warning.to_string(),
113                    PgErrorCode::Ec01000.code(),
114                    warning.clone(),
115                )
116                .into(),
117            ))
118            .await?;
119    }
120
121    Ok(())
122}
123
124pub(crate) fn output_to_query_response(
125    query_ctx: QueryContextRef,
126    output: Result<Output>,
127    field_format: &Format,
128) -> PgWireResult<Response> {
129    match output {
130        Ok(o) => match o.data {
131            OutputData::AffectedRows(rows) => {
132                Ok(Response::Execution(Tag::new("OK").with_rows(rows)))
133            }
134            OutputData::Stream(record_stream) => {
135                let schema = record_stream.schema();
136                recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
137            }
138            OutputData::RecordBatches(recordbatches) => {
139                let schema = recordbatches.schema();
140                recordbatches_to_query_response(
141                    query_ctx,
142                    recordbatches.as_stream(),
143                    schema,
144                    field_format,
145                )
146            }
147        },
148        Err(e) => Err(convert_err(e)),
149    }
150}
151
152fn recordbatches_to_query_response<S>(
153    query_ctx: QueryContextRef,
154    recordbatches_stream: S,
155    schema: SchemaRef,
156    field_format: &Format,
157) -> PgWireResult<Response>
158where
159    S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
160{
161    let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?);
162    let pg_schema_ref = pg_schema.clone();
163    let data_row_stream = recordbatches_stream
164        .map(move |result| match result {
165            Ok(record_batch) => stream::iter(RecordBatchRowIterator::new(
166                query_ctx.clone(),
167                pg_schema_ref.clone(),
168                record_batch,
169            ))
170            .boxed(),
171            Err(e) => stream::once(future::err(convert_err(e))).boxed(),
172        })
173        .flatten();
174
175    Ok(Response::Query(QueryResponse::new(
176        pg_schema,
177        data_row_stream,
178    )))
179}
180
181pub struct DefaultQueryParser {
182    query_handler: ServerSqlQueryHandlerRef,
183    session: Arc<Session>,
184    compatibility_parser: PostgresCompatibilityParser,
185}
186
187impl DefaultQueryParser {
188    pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
189        DefaultQueryParser {
190            query_handler,
191            session,
192            compatibility_parser: PostgresCompatibilityParser::new(),
193        }
194    }
195}
196
197#[async_trait]
198impl QueryParser for DefaultQueryParser {
199    type Statement = SqlPlan;
200
201    async fn parse_sql<C>(
202        &self,
203        _client: &C,
204        sql: &str,
205        _types: &[Option<Type>],
206    ) -> PgWireResult<Self::Statement> {
207        crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
208        let query_ctx = self.session.new_query_context();
209
210        // do not parse if query is empty or matches rules
211        if sql.is_empty() || fixtures::matches(sql) {
212            return Ok(SqlPlan {
213                query: sql.to_owned(),
214                statement: None,
215                plan: None,
216                schema: None,
217            });
218        }
219
220        let sql = if let Ok(mut statements) = self.compatibility_parser.parse(sql) {
221            statements.remove(0).to_string()
222        } else {
223            // bypass the error: it can run into error because of different
224            // versions of sqlparser
225            sql.to_string()
226        };
227
228        let mut stmts = ParserContext::create_with_dialect(
229            &sql,
230            &PostgreSqlDialect {},
231            ParseOptions::default(),
232        )
233        .map_err(convert_err)?;
234        if stmts.len() != 1 {
235            Err(PgWireError::UserError(Box::new(ErrorInfo::from(
236                PgErrorCode::Ec42P14,
237            ))))
238        } else {
239            let stmt = stmts.remove(0);
240
241            let describe_result = self
242                .query_handler
243                .do_describe(stmt.clone(), query_ctx)
244                .await
245                .map_err(convert_err)?;
246
247            let (plan, schema) = if let Some(DescribeResult {
248                logical_plan,
249                schema,
250            }) = describe_result
251            {
252                (Some(logical_plan), Some(schema))
253            } else {
254                (None, None)
255            };
256
257            Ok(SqlPlan {
258                query: sql.clone(),
259                statement: Some(stmt),
260                plan,
261                schema,
262            })
263        }
264    }
265
266    fn get_parameter_types(&self, _stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
267        // we have our own implementation of describes in ExtendedQueryHandler
268        // so we don't use these methods
269        Err(PgWireError::ApiError(
270            "get_parameter_types is not expected to be called".into(),
271        ))
272    }
273
274    fn get_result_schema(
275        &self,
276        _stmt: &Self::Statement,
277        _column_format: Option<&Format>,
278    ) -> PgWireResult<Vec<FieldInfo>> {
279        // we have our own implementation of describes in ExtendedQueryHandler
280        // so we don't use these methods
281        Err(PgWireError::ApiError(
282            "get_result_schema is not expected to be called".into(),
283        ))
284    }
285}
286
287#[async_trait]
288impl ExtendedQueryHandler for PostgresServerHandlerInner {
289    type Statement = SqlPlan;
290    type QueryParser = DefaultQueryParser;
291
292    fn query_parser(&self) -> Arc<Self::QueryParser> {
293        self.query_parser.clone()
294    }
295
296    async fn do_query<C>(
297        &self,
298        client: &mut C,
299        portal: &Portal<Self::Statement>,
300        _max_rows: usize,
301    ) -> PgWireResult<Response>
302    where
303        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
304        C::Error: Debug,
305        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
306    {
307        let query_ctx = self.session.new_query_context();
308        let db = query_ctx.get_db_string();
309        let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER
310            .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
311            .start_timer();
312
313        let sql_plan = &portal.statement.statement;
314
315        if sql_plan.query.is_empty() {
316            // early return if query is empty
317            return Ok(Response::EmptyQuery);
318        }
319
320        if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
321            send_warning_opt(client, query_ctx).await?;
322            // if the statement matches our predefined rules, return it early
323            return Ok(resps.remove(0));
324        }
325
326        let output = if let Some(plan) = &sql_plan.plan {
327            let plan = plan
328                .clone()
329                .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values(
330                    plan, portal,
331                )?))
332                .context(DataFusionSnafu)
333                .map_err(convert_err)?;
334            self.query_handler
335                .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
336                .await
337        } else {
338            // manually replace variables in prepared statement when no
339            // logical_plan is generated. This happens when logical plan is not
340            // supported for certain statements.
341            let mut sql = sql_plan.query.clone();
342            for i in 0..portal.parameter_len() {
343                sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
344            }
345
346            self.query_handler
347                .do_query(&sql, query_ctx.clone())
348                .await
349                .remove(0)
350        };
351
352        send_warning_opt(client, query_ctx.clone()).await?;
353        output_to_query_response(query_ctx, output, &portal.result_column_format)
354    }
355
356    async fn do_describe_statement<C>(
357        &self,
358        _client: &mut C,
359        stmt: &StoredStatement<Self::Statement>,
360    ) -> PgWireResult<DescribeStatementResponse>
361    where
362        C: ClientInfo + Unpin + Send + Sync,
363    {
364        let sql_plan = &stmt.statement;
365        // client provided parameter types, can be empty if client doesn't try to parse statement
366        let provided_param_types = &stmt.parameter_types;
367        let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
368            let param_types = plan
369                .get_parameter_types()
370                .context(DataFusionSnafu)
371                .map_err(convert_err)?
372                .into_iter()
373                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
374                .collect();
375
376            let types = param_types_to_pg_types(&param_types).map_err(convert_err)?;
377
378            Some(types)
379        } else {
380            None
381        };
382
383        let param_count = if provided_param_types.is_empty() {
384            server_inferenced_types
385                .as_ref()
386                .map(|types| types.len())
387                .unwrap_or(0)
388        } else {
389            provided_param_types.len()
390        };
391
392        let param_types = (0..param_count)
393            .map(|i| {
394                let client_type = provided_param_types.get(i);
395                // use server type when client provided type is None (oid: 0 or other invalid values)
396                match client_type {
397                    Some(Some(client_type)) => client_type.clone(),
398                    _ => server_inferenced_types
399                        .as_ref()
400                        .and_then(|types| types.get(i).cloned())
401                        .unwrap_or(Type::UNKNOWN),
402                }
403            })
404            .collect::<Vec<_>>();
405
406        if let Some(schema) = &sql_plan.schema {
407            schema_to_pg(schema, &Format::UnifiedBinary)
408                .map(|fields| DescribeStatementResponse::new(param_types, fields))
409                .map_err(convert_err)
410        } else {
411            if let Some(mut resp) =
412                fixtures::process(&sql_plan.query, self.session.new_query_context())
413                && let Response::Query(query_response) = resp.remove(0)
414            {
415                return Ok(DescribeStatementResponse::new(
416                    param_types,
417                    (*query_response.row_schema()).clone(),
418                ));
419            }
420
421            Ok(DescribeStatementResponse::new(param_types, vec![]))
422        }
423    }
424
425    async fn do_describe_portal<C>(
426        &self,
427        _client: &mut C,
428        portal: &Portal<Self::Statement>,
429    ) -> PgWireResult<DescribePortalResponse>
430    where
431        C: ClientInfo + Unpin + Send + Sync,
432    {
433        let sql_plan = &portal.statement.statement;
434        let format = &portal.result_column_format;
435
436        match sql_plan.statement.as_ref() {
437            Some(Statement::Query(_)) => {
438                // if the query has a schema, it is managed by datafusion, use the schema
439                if let Some(schema) = &sql_plan.schema {
440                    schema_to_pg(schema, format)
441                        .map(DescribePortalResponse::new)
442                        .map_err(convert_err)
443                } else {
444                    // fallback to NoData
445                    Ok(DescribePortalResponse::new(vec![]))
446                }
447            }
448            // We can cover only part of show statements
449            // these show create statements will return 2 columns
450            Some(Statement::ShowCreateDatabase(_))
451            | Some(Statement::ShowCreateTable(_))
452            | Some(Statement::ShowCreateFlow(_))
453            | Some(Statement::ShowCreateView(_)) => Ok(DescribePortalResponse::new(vec![
454                FieldInfo::new(
455                    "name".to_string(),
456                    None,
457                    None,
458                    Type::TEXT,
459                    format.format_for(0),
460                ),
461                FieldInfo::new(
462                    "create_statement".to_string(),
463                    None,
464                    None,
465                    Type::TEXT,
466                    format.format_for(1),
467                ),
468            ])),
469            // single column show statements
470            Some(Statement::ShowTables(_))
471            | Some(Statement::ShowFlows(_))
472            | Some(Statement::ShowViews(_)) => {
473                Ok(DescribePortalResponse::new(vec![FieldInfo::new(
474                    "name".to_string(),
475                    None,
476                    None,
477                    Type::TEXT,
478                    format.format_for(0),
479                )]))
480            }
481            // we will not support other show statements for extended query protocol at least for now.
482            // because the return columns is not predictable at this stage
483            _ => {
484                // test if query caught by fixture
485                if let Some(mut resp) =
486                    fixtures::process(&sql_plan.query, self.session.new_query_context())
487                    && let Response::Query(query_response) = resp.remove(0)
488                {
489                    Ok(DescribePortalResponse::new(
490                        (*query_response.row_schema()).clone(),
491                    ))
492                } else {
493                    // fallback to NoData
494                    Ok(DescribePortalResponse::new(vec![]))
495                }
496            }
497        }
498    }
499}
500
501impl ErrorHandler for PostgresServerHandlerInner {
502    fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
503    where
504        C: ClientInfo,
505    {
506        debug!("Postgres interface error {}", error)
507    }
508}