session/
context.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::{Display, Formatter};
17use std::net::SocketAddr;
18use std::sync::{Arc, RwLock};
19use std::time::Duration;
20
21use api::v1::region::RegionRequestHeader;
22use api::v1::ExplainOptions;
23use arc_swap::ArcSwap;
24use auth::UserInfoRef;
25use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
26use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
27use common_recordbatch::cursor::RecordBatchStreamCursor;
28use common_telemetry::warn;
29use common_time::timezone::parse_timezone;
30use common_time::Timezone;
31use derive_builder::Builder;
32use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
33
34use crate::protocol_ctx::ProtocolCtx;
35use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
36use crate::{MutableInner, ReadPreference};
37
/// Shared, reference-counted handle to a [`QueryContext`].
pub type QueryContextRef = Arc<QueryContext>;
/// Shared, reference-counted handle to a [`ConnInfo`].
pub type ConnInfoRef = Arc<ConnInfo>;

/// Number of open cursors on one connection above which
/// [`QueryContext::insert_cursor`] logs a warning.
const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
42
/// Context of a single query: catalog/schema, session-level mutable state, SQL dialect,
/// protocol channel and connection details.
///
/// The derived `Clone` copies the `Arc` handles, so clones share the session data,
/// per-query mutable data, snapshot/SST sequence maps and configuration variables.
#[derive(Debug, Builder, Clone)]
#[builder(pattern = "owned")]
// `build` is hand-written on `QueryContextBuilder` so that unset fields get defaults.
#[builder(build_fn(skip))]
pub struct QueryContext {
    /// Name of the current catalog (stored by value, not shared with clones).
    current_catalog: String,
    /// Mapping of RegionId to SequenceNumber for snapshot reads: the read should only
    /// contain data that was committed before (and including) the given sequence number.
    /// This field is intended to be filled only if `extensions` contains the pair
    /// `"snapshot_read"` = `"true"`.
    /// NOTE(review): `From<api::v1::QueryContext>` copies it unconditionally — confirm.
    snapshot_seqs: Arc<RwLock<HashMap<u64, u64>>>,
    /// Mappings of the RegionId to the minimal sequence of SST file to scan.
    sst_min_sequences: Arc<RwLock<HashMap<u64, u64>>>,
    // We use `Arc<RwLock<_>>` for modifiable fields so they can be changed through `&self`.
    /// Mutable session state (schema, timezone, read preference, user, query timeout,
    /// cursors).
    #[builder(default)]
    mutable_session_data: Arc<RwLock<MutableInner>>,
    /// Mutable state that is only valid for the current query (warning, explain settings).
    #[builder(default)]
    mutable_query_context_data: Arc<RwLock<QueryContextMutableFields>>,
    /// SQL dialect of this context; `build` defaults it to `GreptimeDbDialect`.
    sql_dialect: Arc<dyn Dialect + Send + Sync>,
    /// Free-form key/value extensions attached to the query.
    #[builder(default)]
    extensions: HashMap<String, String>,
    /// The configuration parameters store values that are set by the user.
    #[builder(default)]
    configuration_parameter: Arc<ConfigurationVariables>,
    /// Track which protocol the query comes from.
    #[builder(default)]
    channel: Channel,
    /// Process id for managing on-going queries
    #[builder(default)]
    process_id: u32,
    /// Connection information
    #[builder(default)]
    conn_info: ConnInfo,
    /// Protocol specific context
    #[builder(default)]
    protocol_ctx: ProtocolCtx,
}
78
/// Fields holding data that is only valid for the current query context.
#[derive(Debug, Builder, Clone, Default)]
pub struct QueryContextMutableFields {
    /// Warning attached to the current query, if any.
    warning: Option<String>,
    // TODO: remove this when format is supported in datafusion
    /// EXPLAIN output format set via `set_explain_format`, if any.
    explain_format: Option<String>,
    /// Explain options to control the verbose analyze output.
    explain_options: Option<ExplainOptions>,
}
88
89impl Display for QueryContext {
90    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
91        write!(
92            f,
93            "QueryContext{{catalog: {}, schema: {}}}",
94            self.current_catalog(),
95            self.current_schema()
96        )
97    }
98}
99
100impl QueryContextBuilder {
101    pub fn current_schema(mut self, schema: String) -> Self {
102        if self.mutable_session_data.is_none() {
103            self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
104        }
105
106        // safe for unwrap because previous none check
107        self.mutable_session_data
108            .as_mut()
109            .unwrap()
110            .write()
111            .unwrap()
112            .schema = schema;
113        self
114    }
115
116    pub fn timezone(mut self, timezone: Timezone) -> Self {
117        if self.mutable_session_data.is_none() {
118            self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
119        }
120
121        self.mutable_session_data
122            .as_mut()
123            .unwrap()
124            .write()
125            .unwrap()
126            .timezone = timezone;
127        self
128    }
129
130    pub fn explain_options(mut self, explain_options: Option<ExplainOptions>) -> Self {
131        self.mutable_query_context_data
132            .get_or_insert_default()
133            .write()
134            .unwrap()
135            .explain_options = explain_options;
136        self
137    }
138
139    pub fn read_preference(mut self, read_preference: ReadPreference) -> Self {
140        self.mutable_session_data
141            .get_or_insert_default()
142            .write()
143            .unwrap()
144            .read_preference = read_preference;
145        self
146    }
147}
148
149impl From<&RegionRequestHeader> for QueryContext {
150    fn from(value: &RegionRequestHeader) -> Self {
151        if let Some(ctx) = &value.query_context {
152            ctx.clone().into()
153        } else {
154            QueryContextBuilder::default().build()
155        }
156    }
157}
158
159impl From<api::v1::QueryContext> for QueryContext {
160    fn from(ctx: api::v1::QueryContext) -> Self {
161        let sequences = ctx.snapshot_seqs.as_ref();
162        QueryContextBuilder::default()
163            .current_catalog(ctx.current_catalog)
164            .current_schema(ctx.current_schema)
165            .timezone(parse_timezone(Some(&ctx.timezone)))
166            .extensions(ctx.extensions)
167            .channel(ctx.channel.into())
168            .snapshot_seqs(Arc::new(RwLock::new(
169                sequences
170                    .map(|x| x.snapshot_seqs.clone())
171                    .unwrap_or_default(),
172            )))
173            .sst_min_sequences(Arc::new(RwLock::new(
174                sequences
175                    .map(|x| x.sst_min_sequences.clone())
176                    .unwrap_or_default(),
177            )))
178            .explain_options(ctx.explain)
179            .build()
180    }
181}
182
183impl From<QueryContext> for api::v1::QueryContext {
184    fn from(
185        QueryContext {
186            current_catalog,
187            mutable_session_data: mutable_inner,
188            extensions,
189            channel,
190            snapshot_seqs,
191            sst_min_sequences,
192            mutable_query_context_data,
193            ..
194        }: QueryContext,
195    ) -> Self {
196        let explain = mutable_query_context_data.read().unwrap().explain_options;
197        let mutable_inner = mutable_inner.read().unwrap();
198        api::v1::QueryContext {
199            current_catalog,
200            current_schema: mutable_inner.schema.clone(),
201            timezone: mutable_inner.timezone.to_string(),
202            extensions,
203            channel: channel as u32,
204            snapshot_seqs: Some(api::v1::SnapshotSequences {
205                snapshot_seqs: snapshot_seqs.read().unwrap().clone(),
206                sst_min_sequences: sst_min_sequences.read().unwrap().clone(),
207            }),
208            explain,
209        }
210    }
211}
212
213impl From<&QueryContext> for api::v1::QueryContext {
214    fn from(ctx: &QueryContext) -> Self {
215        ctx.clone().into()
216    }
217}
218
219impl QueryContext {
220    pub fn arc() -> QueryContextRef {
221        Arc::new(QueryContextBuilder::default().build())
222    }
223
224    pub fn with(catalog: &str, schema: &str) -> QueryContext {
225        QueryContextBuilder::default()
226            .current_catalog(catalog.to_string())
227            .current_schema(schema.to_string())
228            .build()
229    }
230
231    pub fn with_channel(catalog: &str, schema: &str, channel: Channel) -> QueryContext {
232        QueryContextBuilder::default()
233            .current_catalog(catalog.to_string())
234            .current_schema(schema.to_string())
235            .channel(channel)
236            .build()
237    }
238
239    pub fn with_db_name(db_name: Option<&str>) -> QueryContext {
240        let (catalog, schema) = db_name
241            .map(|db| {
242                let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
243                (catalog, schema)
244            })
245            .unwrap_or_else(|| {
246                (
247                    DEFAULT_CATALOG_NAME.to_string(),
248                    DEFAULT_SCHEMA_NAME.to_string(),
249                )
250            });
251        QueryContextBuilder::default()
252            .current_catalog(catalog)
253            .current_schema(schema.to_string())
254            .build()
255    }
256
257    pub fn current_schema(&self) -> String {
258        self.mutable_session_data.read().unwrap().schema.clone()
259    }
260
261    pub fn set_current_schema(&self, new_schema: &str) {
262        self.mutable_session_data.write().unwrap().schema = new_schema.to_string();
263    }
264
265    pub fn current_catalog(&self) -> &str {
266        &self.current_catalog
267    }
268
269    pub fn set_current_catalog(&mut self, new_catalog: &str) {
270        self.current_catalog = new_catalog.to_string();
271    }
272
273    pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) {
274        &*self.sql_dialect
275    }
276
277    pub fn get_db_string(&self) -> String {
278        let catalog = self.current_catalog();
279        let schema = self.current_schema();
280        build_db_string(catalog, &schema)
281    }
282
283    pub fn timezone(&self) -> Timezone {
284        self.mutable_session_data.read().unwrap().timezone.clone()
285    }
286
287    pub fn set_timezone(&self, timezone: Timezone) {
288        self.mutable_session_data.write().unwrap().timezone = timezone;
289    }
290
291    pub fn read_preference(&self) -> ReadPreference {
292        self.mutable_session_data.read().unwrap().read_preference
293    }
294
295    pub fn set_read_preference(&self, read_preference: ReadPreference) {
296        self.mutable_session_data.write().unwrap().read_preference = read_preference;
297    }
298
299    pub fn current_user(&self) -> UserInfoRef {
300        self.mutable_session_data.read().unwrap().user_info.clone()
301    }
302
303    pub fn set_current_user(&self, user: UserInfoRef) {
304        self.mutable_session_data.write().unwrap().user_info = user;
305    }
306
307    pub fn set_extension<S1: Into<String>, S2: Into<String>>(&mut self, key: S1, value: S2) {
308        self.extensions.insert(key.into(), value.into());
309    }
310
311    pub fn extension<S: AsRef<str>>(&self, key: S) -> Option<&str> {
312        self.extensions.get(key.as_ref()).map(|v| v.as_str())
313    }
314
315    pub fn extensions(&self) -> HashMap<String, String> {
316        self.extensions.clone()
317    }
318
319    /// Default to double quote and fallback to back quote
320    pub fn quote_style(&self) -> char {
321        if self.sql_dialect().is_delimited_identifier_start('"') {
322            '"'
323        } else if self.sql_dialect().is_delimited_identifier_start('\'') {
324            '\''
325        } else {
326            '`'
327        }
328    }
329
330    pub fn configuration_parameter(&self) -> &ConfigurationVariables {
331        &self.configuration_parameter
332    }
333
334    pub fn channel(&self) -> Channel {
335        self.channel
336    }
337
338    pub fn set_channel(&mut self, channel: Channel) {
339        self.channel = channel;
340    }
341
342    pub fn warning(&self) -> Option<String> {
343        self.mutable_query_context_data
344            .read()
345            .unwrap()
346            .warning
347            .clone()
348    }
349
350    pub fn set_warning(&self, msg: String) {
351        self.mutable_query_context_data.write().unwrap().warning = Some(msg);
352    }
353
354    pub fn explain_format(&self) -> Option<String> {
355        self.mutable_query_context_data
356            .read()
357            .unwrap()
358            .explain_format
359            .clone()
360    }
361
362    pub fn set_explain_format(&self, format: String) {
363        self.mutable_query_context_data
364            .write()
365            .unwrap()
366            .explain_format = Some(format);
367    }
368
369    pub fn explain_verbose(&self) -> bool {
370        self.mutable_query_context_data
371            .read()
372            .unwrap()
373            .explain_options
374            .map(|opts| opts.verbose)
375            .unwrap_or(false)
376    }
377
378    pub fn set_explain_verbose(&self, verbose: bool) {
379        self.mutable_query_context_data
380            .write()
381            .unwrap()
382            .explain_options
383            .get_or_insert_default()
384            .verbose = verbose;
385    }
386
387    pub fn query_timeout(&self) -> Option<Duration> {
388        self.mutable_session_data.read().unwrap().query_timeout
389    }
390
391    pub fn query_timeout_as_millis(&self) -> u128 {
392        let timeout = self.mutable_session_data.read().unwrap().query_timeout;
393        if let Some(t) = timeout {
394            return t.as_millis();
395        }
396        0
397    }
398
399    pub fn set_query_timeout(&self, timeout: Duration) {
400        self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
401    }
402
403    pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
404        let mut guard = self.mutable_session_data.write().unwrap();
405        guard.cursors.insert(name, Arc::new(rb));
406
407        let cursor_count = guard.cursors.len();
408        if cursor_count > CURSOR_COUNT_WARNING_LIMIT {
409            warn!("Current connection has {} open cursors", cursor_count);
410        }
411    }
412
413    pub fn remove_cursor(&self, name: &str) {
414        let mut guard = self.mutable_session_data.write().unwrap();
415        guard.cursors.remove(name);
416    }
417
418    pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
419        let guard = self.mutable_session_data.read().unwrap();
420        let rb = guard.cursors.get(name);
421        rb.cloned()
422    }
423
424    pub fn snapshots(&self) -> HashMap<u64, u64> {
425        self.snapshot_seqs.read().unwrap().clone()
426    }
427
428    pub fn get_snapshot(&self, region_id: u64) -> Option<u64> {
429        self.snapshot_seqs.read().unwrap().get(&region_id).cloned()
430    }
431
432    /// Returns `true` if the session can cast strings to numbers in MySQL style.
433    pub fn auto_string_to_numeric(&self) -> bool {
434        matches!(self.channel, Channel::Mysql)
435    }
436
437    /// Finds the minimal sequence of SST files to scan of a Region.
438    pub fn sst_min_sequence(&self, region_id: u64) -> Option<u64> {
439        self.sst_min_sequences
440            .read()
441            .unwrap()
442            .get(&region_id)
443            .copied()
444    }
445
446    pub fn process_id(&self) -> u32 {
447        self.process_id
448    }
449
450    /// Get client information
451    pub fn conn_info(&self) -> &ConnInfo {
452        &self.conn_info
453    }
454
455    pub fn protocol_ctx(&self) -> &ProtocolCtx {
456        &self.protocol_ctx
457    }
458
459    pub fn set_protocol_ctx(&mut self, protocol_ctx: ProtocolCtx) {
460        self.protocol_ctx = protocol_ctx;
461    }
462}
463
464impl QueryContextBuilder {
465    pub fn build(self) -> QueryContext {
466        let channel = self.channel.unwrap_or_default();
467        QueryContext {
468            current_catalog: self
469                .current_catalog
470                .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()),
471            snapshot_seqs: self.snapshot_seqs.unwrap_or_default(),
472            sst_min_sequences: self.sst_min_sequences.unwrap_or_default(),
473            mutable_session_data: self.mutable_session_data.unwrap_or_default(),
474            mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(),
475            sql_dialect: self
476                .sql_dialect
477                .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
478            extensions: self.extensions.unwrap_or_default(),
479            configuration_parameter: self
480                .configuration_parameter
481                .unwrap_or_else(|| Arc::new(ConfigurationVariables::default())),
482            channel,
483            process_id: self.process_id.unwrap_or_default(),
484            conn_info: self.conn_info.unwrap_or_default(),
485            protocol_ctx: self.protocol_ctx.unwrap_or_default(),
486        }
487    }
488
489    pub fn set_extension(mut self, key: String, value: String) -> Self {
490        self.extensions
491            .get_or_insert_with(HashMap::new)
492            .insert(key, value);
493        self
494    }
495}
496
/// Information about the client connection a query arrived on.
#[derive(Debug, Clone, Default)]
pub struct ConnInfo {
    /// Socket address of the client, if known.
    pub client_addr: Option<SocketAddr>,
    /// Protocol channel the connection uses.
    pub channel: Channel,
}
502
503impl Display for ConnInfo {
504    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
505        write!(
506            f,
507            "{}[{}]",
508            self.channel,
509            self.client_addr
510                .map(|addr| addr.to_string())
511                .as_deref()
512                .unwrap_or("unknown client addr")
513        )
514    }
515}
516
517impl ConnInfo {
518    pub fn new(client_addr: Option<SocketAddr>, channel: Channel) -> Self {
519        Self {
520            client_addr,
521            channel,
522        }
523    }
524}
525
/// Protocol (front-end) through which a query arrived.
///
/// The explicit discriminants go over the wire: `channel as u32` when converting to
/// `api::v1::QueryContext`, decoded again by `From<u32>`. They must therefore stay
/// stable.
#[derive(Debug, PartialEq, Default, Clone, Copy)]
#[repr(u8)]
pub enum Channel {
    /// Unset or unrecognized channel; also what `From<u32>` yields for unknown ids.
    #[default]
    Unknown = 0,

    Mysql = 1,
    Postgres = 2,
    HttpSql = 3,
    Prometheus = 4,
    Otlp = 5,
    Grpc = 6,
    Influx = 7,
    Opentsdb = 8,
    Loki = 9,
    Elasticsearch = 10,
    Jaeger = 11,
    Log = 12,
    Promql = 13,
}
546
547impl From<u32> for Channel {
548    fn from(value: u32) -> Self {
549        match value {
550            1 => Self::Mysql,
551            2 => Self::Postgres,
552            3 => Self::HttpSql,
553            4 => Self::Prometheus,
554            5 => Self::Otlp,
555            6 => Self::Grpc,
556            7 => Self::Influx,
557            8 => Self::Opentsdb,
558            9 => Self::Loki,
559            10 => Self::Elasticsearch,
560            11 => Self::Jaeger,
561            12 => Self::Log,
562            13 => Self::Promql,
563            _ => Self::Unknown,
564        }
565    }
566}
567
568impl Channel {
569    pub fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
570        match self {
571            Channel::Mysql => Arc::new(MySqlDialect {}),
572            Channel::Postgres => Arc::new(PostgreSqlDialect {}),
573            _ => Arc::new(GenericDialect {}),
574        }
575    }
576}
577
578impl Display for Channel {
579    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
580        write!(f, "{}", self.as_ref())
581    }
582}
583
584impl AsRef<str> for Channel {
585    fn as_ref(&self) -> &str {
586        match self {
587            Channel::Mysql => "mysql",
588            Channel::Postgres => "postgres",
589            Channel::HttpSql => "httpsql",
590            Channel::Prometheus => "prometheus",
591            Channel::Otlp => "otlp",
592            Channel::Grpc => "grpc",
593            Channel::Influx => "influx",
594            Channel::Opentsdb => "opentsdb",
595            Channel::Loki => "loki",
596            Channel::Elasticsearch => "elasticsearch",
597            Channel::Jaeger => "jaeger",
598            Channel::Log => "log",
599            Channel::Promql => "promql",
600            Channel::Unknown => "unknown",
601        }
602    }
603}
604
/// Configuration variables set by the user and attached to a query context.
///
/// Each value sits in an `ArcSwap`, so it can be read and replaced through `&self`.
#[derive(Default, Debug)]
pub struct ConfigurationVariables {
    /// PostgreSQL bytea output setting.
    postgres_bytea_output: ArcSwap<PGByteaOutputValue>,
    /// PostgreSQL date style setting: output style plus date field order.
    pg_datestyle_format: ArcSwap<(PGDateTimeStyle, PGDateOrder)>,
    /// Flag read through `allow_query_fallback()`; `false` by default.
    allow_query_fallback: ArcSwap<bool>,
}
611
612impl Clone for ConfigurationVariables {
613    fn clone(&self) -> Self {
614        Self {
615            postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()),
616            pg_datestyle_format: ArcSwap::new(self.pg_datestyle_format.load().clone()),
617            allow_query_fallback: ArcSwap::new(self.allow_query_fallback.load().clone()),
618        }
619    }
620}
621
622impl ConfigurationVariables {
623    pub fn new() -> Self {
624        Self::default()
625    }
626
627    pub fn set_postgres_bytea_output(&self, value: PGByteaOutputValue) {
628        let _ = self.postgres_bytea_output.swap(Arc::new(value));
629    }
630
631    pub fn postgres_bytea_output(&self) -> Arc<PGByteaOutputValue> {
632        self.postgres_bytea_output.load().clone()
633    }
634
635    pub fn pg_datetime_style(&self) -> Arc<(PGDateTimeStyle, PGDateOrder)> {
636        self.pg_datestyle_format.load().clone()
637    }
638
639    pub fn set_pg_datetime_style(&self, style: PGDateTimeStyle, order: PGDateOrder) {
640        self.pg_datestyle_format.swap(Arc::new((style, order)));
641    }
642
643    pub fn allow_query_fallback(&self) -> bool {
644        **self.allow_query_fallback.load()
645    }
646
647    pub fn set_allow_query_fallback(&self, allow: bool) {
648        self.allow_query_fallback.swap(Arc::new(allow));
649    }
650}
651
#[cfg(test)]
mod test {
    use common_catalog::consts::DEFAULT_CATALOG_NAME;

    use super::*;
    use crate::context::Channel;
    use crate::Session;

    /// Checks user info, connection info and process id exposed by a fresh session.
    #[test]
    fn test_session() {
        let session = Session::new(
            Some("127.0.0.1:9000".parse().unwrap()),
            Channel::Mysql,
            Default::default(),
            100,
        );

        // user info
        assert_eq!(session.user_info().username(), "greptime");

        // connection info: channel, client address and its display form
        let conn_info = session.conn_info();
        assert_eq!(conn_info.channel, Channel::Mysql);
        let addr = conn_info
            .client_addr
            .expect("client address should be recorded");
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
        assert_eq!(addr.port(), 9000);
        assert_eq!("mysql[127.0.0.1:9000]", conn_info.to_string());

        assert_eq!(100, session.process_id());
    }

    /// The default catalog is omitted from the db string; any other catalog is prefixed.
    #[test]
    fn test_context_db_string() {
        let cases = [
            ("a0b1c2d3", "a0b1c2d3-test"),
            (DEFAULT_CATALOG_NAME, "test"),
        ];
        for (catalog, expected) in cases {
            let context = QueryContext::with(catalog, "test");
            assert_eq!(expected, context.get_db_string());
        }
    }
}