session/
lib.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod context;
16pub mod hints;
17pub mod protocol_ctx;
18pub mod session_config;
19pub mod table_name;
20
21use std::collections::{HashMap, VecDeque};
22use std::net::SocketAddr;
23use std::sync::{Arc, RwLock};
24use std::time::Duration;
25
26use auth::UserInfoRef;
27use common_catalog::build_db_string;
28use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
29use common_recordbatch::cursor::RecordBatchStreamCursor;
30pub use common_session::ReadPreference;
31use common_time::Timezone;
32use common_time::timezone::get_timezone;
33use context::{ConfigurationVariables, QueryContextBuilder};
34use derive_more::Debug;
35
36use crate::context::{Channel, ConnInfo, QueryContextRef};
37
/// Upper bound on the number of warnings one session retains (plays the same
/// role as MySQL's `max_error_count`). When the bound is reached, the oldest
/// warning is evicted to make room for the newest one.
const MAX_WARNINGS: usize = 64;
40
/// Session for persistent connection such as MySQL, PostgreSQL etc.
///
/// The catalog is stored in its own lock, outside `MutableInner`. Query
/// contexts created by `new_query_context` receive only a copy of it, because
/// the catalog is not allowed to be updated through a query context.
#[derive(Debug)]
pub struct Session {
    /// Current catalog; starts as `DEFAULT_CATALOG_NAME`.
    catalog: RwLock<String>,
    /// Mutable session state; the same `Arc` is handed to every query context
    /// built from this session, so the state is shared rather than copied.
    mutable_inner: Arc<RwLock<MutableInner>>,
    /// Connection information built from the peer address and channel.
    conn_info: ConnInfo,
    /// Configuration variables, shared (via `Arc`) with query contexts.
    configuration_variables: Arc<ConfigurationVariables>,
    /// The process id to use when killing the query.
    process_id: u32,
}

/// Reference-counted handle to a [`Session`].
pub type SessionRef = Arc<Session>;
53
/// A container for mutable items in query context.
///
/// A `Session` and the query contexts it creates share one instance through an
/// `Arc<RwLock<MutableInner>>`. `Debug` here is `derive_more::Debug` (imported
/// at the top of the file), which honors the `#[debug(skip)]` on `cursors`.
#[derive(Debug)]
pub(crate) struct MutableInner {
    /// Current schema; starts as `DEFAULT_SCHEMA_NAME`.
    schema: String,
    /// User info for the session; replaced via `Session::set_user_info`.
    user_info: UserInfoRef,
    /// Session timezone; read/written via `Session::timezone` / `set_timezone`.
    timezone: Timezone,
    /// Optional query timeout; `None` by default.
    query_timeout: Option<Duration>,
    /// Read preference; `ReadPreference::Leader` by default.
    read_preference: ReadPreference,
    /// Open cursors keyed by string (presumably the cursor name — confirm with
    /// callers). Omitted from the `Debug` output.
    #[debug(skip)]
    pub(crate) cursors: HashMap<String, Arc<RecordBatchStreamCursor>>,
    /// Warning messages for MySQL SHOW WARNINGS support, oldest first.
    /// `Session::add_warning` caps the length at `MAX_WARNINGS`.
    warnings: VecDeque<String>,
}
67
68impl Default for MutableInner {
69    fn default() -> Self {
70        Self {
71            schema: DEFAULT_SCHEMA_NAME.into(),
72            user_info: auth::userinfo_by_name(None),
73            timezone: get_timezone(None).clone(),
74            query_timeout: None,
75            read_preference: ReadPreference::Leader,
76            cursors: HashMap::with_capacity(0),
77            warnings: VecDeque::new(),
78        }
79    }
80}
81
82impl Session {
83    pub fn new(
84        addr: Option<SocketAddr>,
85        channel: Channel,
86        configuration_variables: ConfigurationVariables,
87        process_id: u32,
88    ) -> Self {
89        Session {
90            catalog: RwLock::new(DEFAULT_CATALOG_NAME.into()),
91            conn_info: ConnInfo::new(addr, channel),
92            configuration_variables: Arc::new(configuration_variables),
93            mutable_inner: Arc::new(RwLock::new(MutableInner::default())),
94            process_id,
95        }
96    }
97
98    pub fn new_query_context(&self) -> QueryContextRef {
99        QueryContextBuilder::default()
100            // catalog is not allowed for update in query context so we use
101            // string here
102            .current_catalog(self.catalog.read().unwrap().clone())
103            .mutable_session_data(self.mutable_inner.clone())
104            .sql_dialect(self.conn_info.channel.dialect())
105            .configuration_parameter(self.configuration_variables.clone())
106            .channel(self.conn_info.channel)
107            .process_id(self.process_id)
108            .conn_info(self.conn_info.clone())
109            .build()
110            .into()
111    }
112
113    pub fn conn_info(&self) -> &ConnInfo {
114        &self.conn_info
115    }
116
117    pub fn timezone(&self) -> Timezone {
118        self.mutable_inner.read().unwrap().timezone.clone()
119    }
120
121    pub fn read_preference(&self) -> ReadPreference {
122        self.mutable_inner.read().unwrap().read_preference
123    }
124
125    pub fn set_timezone(&self, tz: Timezone) {
126        let mut inner = self.mutable_inner.write().unwrap();
127        inner.timezone = tz;
128    }
129
130    pub fn set_read_preference(&self, read_preference: ReadPreference) {
131        self.mutable_inner.write().unwrap().read_preference = read_preference;
132    }
133
134    pub fn user_info(&self) -> UserInfoRef {
135        self.mutable_inner.read().unwrap().user_info.clone()
136    }
137
138    pub fn set_user_info(&self, user_info: UserInfoRef) {
139        self.mutable_inner.write().unwrap().user_info = user_info;
140    }
141
142    pub fn set_catalog(&self, catalog: String) {
143        *self.catalog.write().unwrap() = catalog;
144    }
145
146    pub fn catalog(&self) -> String {
147        self.catalog.read().unwrap().clone()
148    }
149
150    pub fn schema(&self) -> String {
151        self.mutable_inner.read().unwrap().schema.clone()
152    }
153
154    pub fn set_schema(&self, schema: String) {
155        self.mutable_inner.write().unwrap().schema = schema;
156    }
157
158    pub fn get_db_string(&self) -> String {
159        build_db_string(&self.catalog(), &self.schema())
160    }
161
162    pub fn process_id(&self) -> u32 {
163        self.process_id
164    }
165
166    pub fn warnings_count(&self) -> usize {
167        self.mutable_inner.read().unwrap().warnings.len()
168    }
169
170    pub fn warnings(&self) -> Vec<String> {
171        self.mutable_inner
172            .read()
173            .unwrap()
174            .warnings
175            .iter()
176            .cloned()
177            .collect()
178    }
179
180    /// Add a warning message. If the limit is reached, discard the oldest warning.
181    pub fn add_warning(&self, warning: String) {
182        let mut inner = self.mutable_inner.write().unwrap();
183        if inner.warnings.len() >= MAX_WARNINGS {
184            inner.warnings.pop_front();
185        }
186        inner.warnings.push_back(warning);
187    }
188
189    pub fn clear_warnings(&self) {
190        let mut inner = self.mutable_inner.write().unwrap();
191        if inner.warnings.is_empty() {
192            return;
193        }
194        inner.warnings.clear();
195    }
196}