operator/statement/
cursor.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use common_query::{Output, OutputData};
16use common_recordbatch::cursor::RecordBatchStreamCursor;
17use common_recordbatch::RecordBatches;
18use common_telemetry::tracing;
19use query::parser::QueryStatement;
20use session::context::QueryContextRef;
21use snafu::ResultExt;
22use sql::statements::cursor::{CloseCursor, DeclareCursor, FetchCursor};
23use sql::statements::statement::Statement;
24
25use crate::error::{self, Result};
26use crate::statement::StatementExecutor;
27
28impl StatementExecutor {
29    #[tracing::instrument(skip_all)]
30    pub(super) async fn declare_cursor(
31        &self,
32        declare_cursor: DeclareCursor,
33        query_ctx: QueryContextRef,
34    ) -> Result<Output> {
35        let cursor_name = declare_cursor.cursor_name.to_string();
36
37        if query_ctx.get_cursor(&cursor_name).is_some() {
38            error::CursorExistsSnafu {
39                name: cursor_name.to_string(),
40            }
41            .fail()?;
42        }
43
44        let query_stmt = Statement::Query(declare_cursor.query);
45
46        let output = self
47            .plan_exec(QueryStatement::Sql(query_stmt), query_ctx.clone())
48            .await?;
49        match output.data {
50            OutputData::RecordBatches(rb) => {
51                let rbs = rb.as_stream();
52                query_ctx.insert_cursor(cursor_name, RecordBatchStreamCursor::new(rbs));
53            }
54            OutputData::Stream(rbs) => {
55                query_ctx.insert_cursor(cursor_name, RecordBatchStreamCursor::new(rbs));
56            }
57            // Should not happen because we have query type ensured from parser.
58            OutputData::AffectedRows(_) => error::NotSupportedSnafu {
59                feat: "Non-query statement on cursor",
60            }
61            .fail()?,
62        }
63
64        Ok(Output::new_with_affected_rows(0))
65    }
66
67    #[tracing::instrument(skip_all)]
68    pub(super) async fn fetch_cursor(
69        &self,
70        fetch_cursor: FetchCursor,
71        query_ctx: QueryContextRef,
72    ) -> Result<Output> {
73        let cursor_name = fetch_cursor.cursor_name.to_string();
74        let fetch_size = fetch_cursor.fetch_size;
75        if let Some(rb) = query_ctx.get_cursor(&cursor_name) {
76            let record_batch = rb
77                .take(fetch_size as usize)
78                .await
79                .context(error::BuildRecordBatchSnafu)?;
80            let record_batches =
81                RecordBatches::try_new(record_batch.schema.clone(), vec![record_batch])
82                    .context(error::BuildRecordBatchSnafu)?;
83            Ok(Output::new_with_record_batches(record_batches))
84        } else {
85            error::CursorNotFoundSnafu { name: cursor_name }.fail()
86        }
87    }
88
89    #[tracing::instrument(skip_all)]
90    pub(super) async fn close_cursor(
91        &self,
92        close_cursor: CloseCursor,
93        query_ctx: QueryContextRef,
94    ) -> Result<Output> {
95        query_ctx.remove_cursor(&close_cursor.cursor_name.to_string());
96        Ok(Output::new_with_affected_rows(0))
97    }
98}