session/
context.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::{Display, Formatter};
17use std::net::SocketAddr;
18use std::sync::{Arc, RwLock};
19use std::time::Duration;
20
21use api::v1::region::RegionRequestHeader;
22use api::v1::ExplainOptions;
23use arc_swap::ArcSwap;
24use auth::UserInfoRef;
25use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
26use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
27use common_recordbatch::cursor::RecordBatchStreamCursor;
28use common_telemetry::warn;
29use common_time::timezone::parse_timezone;
30use common_time::Timezone;
31use derive_builder::Builder;
32use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
33
34use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
35use crate::{MutableInner, ReadPreference};
36
/// Shared, reference-counted handle to a [`QueryContext`].
pub type QueryContextRef = Arc<QueryContext>;
/// Shared, reference-counted handle to a [`ConnInfo`].
pub type ConnInfoRef = Arc<ConnInfo>;

/// Number of open cursors in the session data above which `QueryContext::insert_cursor`
/// logs a warning. Exceeding it is not an error: the cursor is still inserted.
const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
41
/// Context of a single query.
///
/// Holds the current catalog and schema, session-level mutable state, the SQL dialect,
/// request extensions, configuration variables and the protocol channel.
///
/// Construct it with [`QueryContextBuilder`]. The builder's `build` is hand-written
/// (`build_fn(skip)`) and supplies a default for every unset field.
///
/// Cloning (derived) shares all `Arc`-wrapped fields (session data, query-scoped data,
/// sequence maps, dialect, configuration) with the original. `current_catalog`,
/// `extensions` and `channel` are copied.
#[derive(Debug, Builder, Clone)]
#[builder(pattern = "owned")]
#[builder(build_fn(skip))]
pub struct QueryContext {
    /// Name of the catalog the query runs in.
    current_catalog: String,
    /// Mapping of RegionId to SequenceNumber for snapshot reads.
    ///
    /// The read should only contain data that was committed before (and including) the
    /// given sequence number. This field will only be filled if extensions contains a pair
    /// of "snapshot_read" and "true".
    snapshot_seqs: Arc<RwLock<HashMap<u64, u64>>>,
    /// Mappings of the RegionId to the minimal sequence of SST file to scan.
    sst_min_sequences: Arc<RwLock<HashMap<u64, u64>>>,
    // Modifiable fields are wrapped in `Arc<RwLock<_>>` so they can be changed through a
    // shared `&QueryContext` and are shared between clones.
    #[builder(default)]
    mutable_session_data: Arc<RwLock<MutableInner>>,
    #[builder(default)]
    mutable_query_context_data: Arc<RwLock<QueryContextMutableFields>>,
    /// SQL dialect used to parse queries in this context.
    sql_dialect: Arc<dyn Dialect + Send + Sync>,
    /// Free-form key/value extensions attached to the query.
    #[builder(default)]
    extensions: HashMap<String, String>,
    // The configuration parameters store the variables that are set by the user.
    #[builder(default)]
    configuration_parameter: Arc<ConfigurationVariables>,
    // Tracks which protocol the query comes from.
    #[builder(default)]
    channel: Channel,
}
68
/// These fields hold data that is only valid for the current query context.
///
/// By contrast, the session-wide state lives in `MutableInner`.
#[derive(Debug, Builder, Clone, Default)]
pub struct QueryContextMutableFields {
    /// Warning attached to the current query, if any.
    warning: Option<String>,
    // TODO: remove this when format is supported in datafusion
    /// Requested EXPLAIN output format, if any.
    explain_format: Option<String>,
    /// Explain options to control the verbose analyze output.
    explain_options: Option<ExplainOptions>,
}
78
79impl Display for QueryContext {
80    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
81        write!(
82            f,
83            "QueryContext{{catalog: {}, schema: {}}}",
84            self.current_catalog(),
85            self.current_schema()
86        )
87    }
88}
89
90impl QueryContextBuilder {
91    pub fn current_schema(mut self, schema: String) -> Self {
92        if self.mutable_session_data.is_none() {
93            self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
94        }
95
96        // safe for unwrap because previous none check
97        self.mutable_session_data
98            .as_mut()
99            .unwrap()
100            .write()
101            .unwrap()
102            .schema = schema;
103        self
104    }
105
106    pub fn timezone(mut self, timezone: Timezone) -> Self {
107        if self.mutable_session_data.is_none() {
108            self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
109        }
110
111        self.mutable_session_data
112            .as_mut()
113            .unwrap()
114            .write()
115            .unwrap()
116            .timezone = timezone;
117        self
118    }
119
120    pub fn explain_options(mut self, explain_options: Option<ExplainOptions>) -> Self {
121        self.mutable_query_context_data
122            .get_or_insert_default()
123            .write()
124            .unwrap()
125            .explain_options = explain_options;
126        self
127    }
128
129    pub fn read_preference(mut self, read_preference: ReadPreference) -> Self {
130        self.mutable_session_data
131            .get_or_insert_default()
132            .write()
133            .unwrap()
134            .read_preference = read_preference;
135        self
136    }
137}
138
139impl From<&RegionRequestHeader> for QueryContext {
140    fn from(value: &RegionRequestHeader) -> Self {
141        if let Some(ctx) = &value.query_context {
142            ctx.clone().into()
143        } else {
144            QueryContextBuilder::default().build()
145        }
146    }
147}
148
149impl From<api::v1::QueryContext> for QueryContext {
150    fn from(ctx: api::v1::QueryContext) -> Self {
151        let sequences = ctx.snapshot_seqs.as_ref();
152        QueryContextBuilder::default()
153            .current_catalog(ctx.current_catalog)
154            .current_schema(ctx.current_schema)
155            .timezone(parse_timezone(Some(&ctx.timezone)))
156            .extensions(ctx.extensions)
157            .channel(ctx.channel.into())
158            .snapshot_seqs(Arc::new(RwLock::new(
159                sequences
160                    .map(|x| x.snapshot_seqs.clone())
161                    .unwrap_or_default(),
162            )))
163            .sst_min_sequences(Arc::new(RwLock::new(
164                sequences
165                    .map(|x| x.sst_min_sequences.clone())
166                    .unwrap_or_default(),
167            )))
168            .explain_options(ctx.explain)
169            .build()
170    }
171}
172
173impl From<QueryContext> for api::v1::QueryContext {
174    fn from(
175        QueryContext {
176            current_catalog,
177            mutable_session_data: mutable_inner,
178            extensions,
179            channel,
180            snapshot_seqs,
181            sst_min_sequences,
182            mutable_query_context_data,
183            ..
184        }: QueryContext,
185    ) -> Self {
186        let explain = mutable_query_context_data.read().unwrap().explain_options;
187        let mutable_inner = mutable_inner.read().unwrap();
188        api::v1::QueryContext {
189            current_catalog,
190            current_schema: mutable_inner.schema.clone(),
191            timezone: mutable_inner.timezone.to_string(),
192            extensions,
193            channel: channel as u32,
194            snapshot_seqs: Some(api::v1::SnapshotSequences {
195                snapshot_seqs: snapshot_seqs.read().unwrap().clone(),
196                sst_min_sequences: sst_min_sequences.read().unwrap().clone(),
197            }),
198            explain,
199        }
200    }
201}
202
203impl From<&QueryContext> for api::v1::QueryContext {
204    fn from(ctx: &QueryContext) -> Self {
205        ctx.clone().into()
206    }
207}
208
209impl QueryContext {
210    pub fn arc() -> QueryContextRef {
211        Arc::new(QueryContextBuilder::default().build())
212    }
213
214    pub fn with(catalog: &str, schema: &str) -> QueryContext {
215        QueryContextBuilder::default()
216            .current_catalog(catalog.to_string())
217            .current_schema(schema.to_string())
218            .build()
219    }
220
221    pub fn with_channel(catalog: &str, schema: &str, channel: Channel) -> QueryContext {
222        QueryContextBuilder::default()
223            .current_catalog(catalog.to_string())
224            .current_schema(schema.to_string())
225            .channel(channel)
226            .build()
227    }
228
229    pub fn with_db_name(db_name: Option<&str>) -> QueryContext {
230        let (catalog, schema) = db_name
231            .map(|db| {
232                let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
233                (catalog, schema)
234            })
235            .unwrap_or_else(|| {
236                (
237                    DEFAULT_CATALOG_NAME.to_string(),
238                    DEFAULT_SCHEMA_NAME.to_string(),
239                )
240            });
241        QueryContextBuilder::default()
242            .current_catalog(catalog)
243            .current_schema(schema.to_string())
244            .build()
245    }
246
247    pub fn current_schema(&self) -> String {
248        self.mutable_session_data.read().unwrap().schema.clone()
249    }
250
251    pub fn set_current_schema(&self, new_schema: &str) {
252        self.mutable_session_data.write().unwrap().schema = new_schema.to_string();
253    }
254
255    pub fn current_catalog(&self) -> &str {
256        &self.current_catalog
257    }
258
259    pub fn set_current_catalog(&mut self, new_catalog: &str) {
260        self.current_catalog = new_catalog.to_string();
261    }
262
263    pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) {
264        &*self.sql_dialect
265    }
266
267    pub fn get_db_string(&self) -> String {
268        let catalog = self.current_catalog();
269        let schema = self.current_schema();
270        build_db_string(catalog, &schema)
271    }
272
273    pub fn timezone(&self) -> Timezone {
274        self.mutable_session_data.read().unwrap().timezone.clone()
275    }
276
277    pub fn set_timezone(&self, timezone: Timezone) {
278        self.mutable_session_data.write().unwrap().timezone = timezone;
279    }
280
281    pub fn read_preference(&self) -> ReadPreference {
282        self.mutable_session_data.read().unwrap().read_preference
283    }
284
285    pub fn set_read_preference(&self, read_preference: ReadPreference) {
286        self.mutable_session_data.write().unwrap().read_preference = read_preference;
287    }
288
289    pub fn current_user(&self) -> UserInfoRef {
290        self.mutable_session_data.read().unwrap().user_info.clone()
291    }
292
293    pub fn set_current_user(&self, user: UserInfoRef) {
294        self.mutable_session_data.write().unwrap().user_info = user;
295    }
296
297    pub fn set_extension<S1: Into<String>, S2: Into<String>>(&mut self, key: S1, value: S2) {
298        self.extensions.insert(key.into(), value.into());
299    }
300
301    pub fn extension<S: AsRef<str>>(&self, key: S) -> Option<&str> {
302        self.extensions.get(key.as_ref()).map(|v| v.as_str())
303    }
304
305    pub fn extensions(&self) -> HashMap<String, String> {
306        self.extensions.clone()
307    }
308
309    /// Default to double quote and fallback to back quote
310    pub fn quote_style(&self) -> char {
311        if self.sql_dialect().is_delimited_identifier_start('"') {
312            '"'
313        } else if self.sql_dialect().is_delimited_identifier_start('\'') {
314            '\''
315        } else {
316            '`'
317        }
318    }
319
320    pub fn configuration_parameter(&self) -> &ConfigurationVariables {
321        &self.configuration_parameter
322    }
323
324    pub fn channel(&self) -> Channel {
325        self.channel
326    }
327
328    pub fn set_channel(&mut self, channel: Channel) {
329        self.channel = channel;
330    }
331
332    pub fn warning(&self) -> Option<String> {
333        self.mutable_query_context_data
334            .read()
335            .unwrap()
336            .warning
337            .clone()
338    }
339
340    pub fn set_warning(&self, msg: String) {
341        self.mutable_query_context_data.write().unwrap().warning = Some(msg);
342    }
343
344    pub fn explain_format(&self) -> Option<String> {
345        self.mutable_query_context_data
346            .read()
347            .unwrap()
348            .explain_format
349            .clone()
350    }
351
352    pub fn set_explain_format(&self, format: String) {
353        self.mutable_query_context_data
354            .write()
355            .unwrap()
356            .explain_format = Some(format);
357    }
358
359    pub fn explain_verbose(&self) -> bool {
360        self.mutable_query_context_data
361            .read()
362            .unwrap()
363            .explain_options
364            .map(|opts| opts.verbose)
365            .unwrap_or(false)
366    }
367
368    pub fn set_explain_verbose(&self, verbose: bool) {
369        self.mutable_query_context_data
370            .write()
371            .unwrap()
372            .explain_options
373            .get_or_insert_default()
374            .verbose = verbose;
375    }
376
377    pub fn query_timeout(&self) -> Option<Duration> {
378        self.mutable_session_data.read().unwrap().query_timeout
379    }
380
381    pub fn query_timeout_as_millis(&self) -> u128 {
382        let timeout = self.mutable_session_data.read().unwrap().query_timeout;
383        if let Some(t) = timeout {
384            return t.as_millis();
385        }
386        0
387    }
388
389    pub fn set_query_timeout(&self, timeout: Duration) {
390        self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
391    }
392
393    pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
394        let mut guard = self.mutable_session_data.write().unwrap();
395        guard.cursors.insert(name, Arc::new(rb));
396
397        let cursor_count = guard.cursors.len();
398        if cursor_count > CURSOR_COUNT_WARNING_LIMIT {
399            warn!("Current connection has {} open cursors", cursor_count);
400        }
401    }
402
403    pub fn remove_cursor(&self, name: &str) {
404        let mut guard = self.mutable_session_data.write().unwrap();
405        guard.cursors.remove(name);
406    }
407
408    pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
409        let guard = self.mutable_session_data.read().unwrap();
410        let rb = guard.cursors.get(name);
411        rb.cloned()
412    }
413
414    pub fn snapshots(&self) -> HashMap<u64, u64> {
415        self.snapshot_seqs.read().unwrap().clone()
416    }
417
418    pub fn get_snapshot(&self, region_id: u64) -> Option<u64> {
419        self.snapshot_seqs.read().unwrap().get(&region_id).cloned()
420    }
421
422    /// Returns `true` if the session can cast strings to numbers in MySQL style.
423    pub fn auto_string_to_numeric(&self) -> bool {
424        matches!(self.channel, Channel::Mysql)
425    }
426
427    /// Finds the minimal sequence of SST files to scan of a Region.
428    pub fn sst_min_sequence(&self, region_id: u64) -> Option<u64> {
429        self.sst_min_sequences
430            .read()
431            .unwrap()
432            .get(&region_id)
433            .copied()
434    }
435}
436
437impl QueryContextBuilder {
438    pub fn build(self) -> QueryContext {
439        let channel = self.channel.unwrap_or_default();
440        QueryContext {
441            current_catalog: self
442                .current_catalog
443                .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()),
444            snapshot_seqs: self.snapshot_seqs.unwrap_or_default(),
445            sst_min_sequences: self.sst_min_sequences.unwrap_or_default(),
446            mutable_session_data: self.mutable_session_data.unwrap_or_default(),
447            mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(),
448            sql_dialect: self
449                .sql_dialect
450                .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
451            extensions: self.extensions.unwrap_or_default(),
452            configuration_parameter: self
453                .configuration_parameter
454                .unwrap_or_else(|| Arc::new(ConfigurationVariables::default())),
455            channel,
456        }
457    }
458
459    pub fn set_extension(mut self, key: String, value: String) -> Self {
460        self.extensions
461            .get_or_insert_with(HashMap::new)
462            .insert(key, value);
463        self
464    }
465}
466
/// Information about a client connection: the client address (when known) and the
/// protocol channel it uses.
#[derive(Debug)]
pub struct ConnInfo {
    /// Address of the client. `None` is rendered as `unknown client addr` by `Display`.
    pub client_addr: Option<SocketAddr>,
    /// Protocol the connection uses.
    pub channel: Channel,
}
472
473impl Display for ConnInfo {
474    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
475        write!(
476            f,
477            "{}[{}]",
478            self.channel,
479            self.client_addr
480                .map(|addr| addr.to_string())
481                .as_deref()
482                .unwrap_or("unknown client addr")
483        )
484    }
485}
486
487impl ConnInfo {
488    pub fn new(client_addr: Option<SocketAddr>, channel: Channel) -> Self {
489        Self {
490            client_addr,
491            channel,
492        }
493    }
494}
495
/// Protocol through which a connection or query arrived.
///
/// The explicit discriminants are the numeric ids used for conversion: `channel as u32`
/// when building the protobuf `QueryContext`, and `Channel::from(u32)` when decoding it.
/// Keep them stable.
#[derive(Debug, PartialEq, Default, Clone, Copy)]
#[repr(u8)]
pub enum Channel {
    /// Unknown protocol: the default, and the result of decoding an unrecognized id.
    #[default]
    Unknown = 0,

    Mysql = 1,
    Postgres = 2,
    Http = 3,
    Prometheus = 4,
    Otlp = 5,
    Grpc = 6,
    Influx = 7,
    Opentsdb = 8,
    Loki = 9,
    Elasticsearch = 10,
    Jaeger = 11,
}
514
515impl From<u32> for Channel {
516    fn from(value: u32) -> Self {
517        match value {
518            1 => Self::Mysql,
519            2 => Self::Postgres,
520            3 => Self::Http,
521            4 => Self::Prometheus,
522            5 => Self::Otlp,
523            6 => Self::Grpc,
524            7 => Self::Influx,
525            8 => Self::Opentsdb,
526            9 => Self::Loki,
527            10 => Self::Elasticsearch,
528            11 => Self::Jaeger,
529            _ => Self::Unknown,
530        }
531    }
532}
533
534impl Channel {
535    pub fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
536        match self {
537            Channel::Mysql => Arc::new(MySqlDialect {}),
538            Channel::Postgres => Arc::new(PostgreSqlDialect {}),
539            _ => Arc::new(GenericDialect {}),
540        }
541    }
542}
543
544impl Display for Channel {
545    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
546        match self {
547            Channel::Mysql => write!(f, "mysql"),
548            Channel::Postgres => write!(f, "postgres"),
549            Channel::Http => write!(f, "http"),
550            Channel::Prometheus => write!(f, "prometheus"),
551            Channel::Otlp => write!(f, "otlp"),
552            Channel::Grpc => write!(f, "grpc"),
553            Channel::Influx => write!(f, "influx"),
554            Channel::Opentsdb => write!(f, "opentsdb"),
555            Channel::Loki => write!(f, "loki"),
556            Channel::Elasticsearch => write!(f, "elasticsearch"),
557            Channel::Jaeger => write!(f, "jaeger"),
558            Channel::Unknown => write!(f, "unknown"),
559        }
560    }
561}
562
/// Configuration variables set by the user for a session.
///
/// Each value sits in an `ArcSwap` cell, so it can be read and replaced through a shared
/// reference.
#[derive(Default, Debug)]
pub struct ConfigurationVariables {
    /// PostgreSQL bytea output setting.
    postgres_bytea_output: ArcSwap<PGByteaOutputValue>,
    /// PostgreSQL date/time style together with the date field order.
    pg_datestyle_format: ArcSwap<(PGDateTimeStyle, PGDateOrder)>,
}
568
569impl Clone for ConfigurationVariables {
570    fn clone(&self) -> Self {
571        Self {
572            postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()),
573            pg_datestyle_format: ArcSwap::new(self.pg_datestyle_format.load().clone()),
574        }
575    }
576}
577
578impl ConfigurationVariables {
579    pub fn new() -> Self {
580        Self::default()
581    }
582
583    pub fn set_postgres_bytea_output(&self, value: PGByteaOutputValue) {
584        let _ = self.postgres_bytea_output.swap(Arc::new(value));
585    }
586
587    pub fn postgres_bytea_output(&self) -> Arc<PGByteaOutputValue> {
588        self.postgres_bytea_output.load().clone()
589    }
590
591    pub fn pg_datetime_style(&self) -> Arc<(PGDateTimeStyle, PGDateOrder)> {
592        self.pg_datestyle_format.load().clone()
593    }
594
595    pub fn set_pg_datetime_style(&self, style: PGDateTimeStyle, order: PGDateOrder) {
596        self.pg_datestyle_format.swap(Arc::new((style, order)));
597    }
598}
599
#[cfg(test)]
mod test {
    use common_catalog::consts::DEFAULT_CATALOG_NAME;

    use super::*;
    use crate::context::Channel;
    use crate::Session;

    /// A MySQL session exposes the default user plus the channel and client address it
    /// was created with.
    #[test]
    fn test_session() {
        let session = Session::new(
            Some("127.0.0.1:9000".parse().unwrap()),
            Channel::Mysql,
            Default::default(),
        );
        assert_eq!(session.user_info().username(), "greptime");

        let conn_info = session.conn_info();
        assert_eq!(conn_info.channel, Channel::Mysql);

        let client_addr = conn_info
            .client_addr
            .expect("client address was provided to Session::new");
        assert_eq!(client_addr.ip().to_string(), "127.0.0.1");
        assert_eq!(client_addr.port(), 9000);

        assert_eq!("mysql[127.0.0.1:9000]", conn_info.to_string());
    }

    /// A non-default catalog is prefixed to the schema; the default catalog is omitted.
    #[test]
    fn test_context_db_string() {
        let cases = [
            ("a0b1c2d3", "a0b1c2d3-test"),
            (DEFAULT_CATALOG_NAME, "test"),
        ];
        for (catalog, expected) in cases {
            let context = QueryContext::with(catalog, "test");
            assert_eq!(expected, context.get_db_string());
        }
    }
}