// session/context.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::{Display, Formatter};
17use std::net::SocketAddr;
18use std::sync::{Arc, RwLock};
19use std::time::Duration;
20
21use api::v1::region::RegionRequestHeader;
22use api::v1::ExplainOptions;
23use arc_swap::ArcSwap;
24use auth::UserInfoRef;
25use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
26use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
27use common_recordbatch::cursor::RecordBatchStreamCursor;
28use common_telemetry::warn;
29use common_time::timezone::parse_timezone;
30use common_time::Timezone;
31use derive_builder::Builder;
32use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
33
34use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
35use crate::{MutableInner, ReadPreference};
36
37pub type QueryContextRef = Arc<QueryContext>;
38pub type ConnInfoRef = Arc<ConnInfo>;
39
40const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
41
42#[derive(Debug, Builder, Clone)]
43#[builder(pattern = "owned")]
44#[builder(build_fn(skip))]
45pub struct QueryContext {
46    current_catalog: String,
47    /// mapping of RegionId to SequenceNumber, for snapshot read, meaning that the read should only
48    /// container data that was committed before(and include) the given sequence number
49    /// this field will only be filled if extensions contains a pair of "snapshot_read" and "true"
50    snapshot_seqs: Arc<RwLock<HashMap<u64, u64>>>,
51    /// Mappings of the RegionId to the minimal sequence of SST file to scan.
52    sst_min_sequences: Arc<RwLock<HashMap<u64, u64>>>,
53    // we use Arc<RwLock>> for modifiable fields
54    #[builder(default)]
55    mutable_session_data: Arc<RwLock<MutableInner>>,
56    #[builder(default)]
57    mutable_query_context_data: Arc<RwLock<QueryContextMutableFields>>,
58    sql_dialect: Arc<dyn Dialect + Send + Sync>,
59    #[builder(default)]
60    extensions: HashMap<String, String>,
61    /// The configuration parameter are used to store the parameters that are set by the user
62    #[builder(default)]
63    configuration_parameter: Arc<ConfigurationVariables>,
64    /// Track which protocol the query comes from.
65    #[builder(default)]
66    channel: Channel,
67    /// Process id for managing on-going queries
68    #[builder(default)]
69    process_id: u32,
70    /// Connection information
71    #[builder(default)]
72    conn_info: ConnInfo,
73}
74
75/// This fields hold data that is only valid to current query context
76#[derive(Debug, Builder, Clone, Default)]
77pub struct QueryContextMutableFields {
78    warning: Option<String>,
79    // TODO: remove this when format is supported in datafusion
80    explain_format: Option<String>,
81    /// Explain options to control the verbose analyze output.
82    explain_options: Option<ExplainOptions>,
83}
84
85impl Display for QueryContext {
86    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
87        write!(
88            f,
89            "QueryContext{{catalog: {}, schema: {}}}",
90            self.current_catalog(),
91            self.current_schema()
92        )
93    }
94}
95
96impl QueryContextBuilder {
97    pub fn current_schema(mut self, schema: String) -> Self {
98        if self.mutable_session_data.is_none() {
99            self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
100        }
101
102        // safe for unwrap because previous none check
103        self.mutable_session_data
104            .as_mut()
105            .unwrap()
106            .write()
107            .unwrap()
108            .schema = schema;
109        self
110    }
111
112    pub fn timezone(mut self, timezone: Timezone) -> Self {
113        if self.mutable_session_data.is_none() {
114            self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
115        }
116
117        self.mutable_session_data
118            .as_mut()
119            .unwrap()
120            .write()
121            .unwrap()
122            .timezone = timezone;
123        self
124    }
125
126    pub fn explain_options(mut self, explain_options: Option<ExplainOptions>) -> Self {
127        self.mutable_query_context_data
128            .get_or_insert_default()
129            .write()
130            .unwrap()
131            .explain_options = explain_options;
132        self
133    }
134
135    pub fn read_preference(mut self, read_preference: ReadPreference) -> Self {
136        self.mutable_session_data
137            .get_or_insert_default()
138            .write()
139            .unwrap()
140            .read_preference = read_preference;
141        self
142    }
143}
144
145impl From<&RegionRequestHeader> for QueryContext {
146    fn from(value: &RegionRequestHeader) -> Self {
147        if let Some(ctx) = &value.query_context {
148            ctx.clone().into()
149        } else {
150            QueryContextBuilder::default().build()
151        }
152    }
153}
154
155impl From<api::v1::QueryContext> for QueryContext {
156    fn from(ctx: api::v1::QueryContext) -> Self {
157        let sequences = ctx.snapshot_seqs.as_ref();
158        QueryContextBuilder::default()
159            .current_catalog(ctx.current_catalog)
160            .current_schema(ctx.current_schema)
161            .timezone(parse_timezone(Some(&ctx.timezone)))
162            .extensions(ctx.extensions)
163            .channel(ctx.channel.into())
164            .snapshot_seqs(Arc::new(RwLock::new(
165                sequences
166                    .map(|x| x.snapshot_seqs.clone())
167                    .unwrap_or_default(),
168            )))
169            .sst_min_sequences(Arc::new(RwLock::new(
170                sequences
171                    .map(|x| x.sst_min_sequences.clone())
172                    .unwrap_or_default(),
173            )))
174            .explain_options(ctx.explain)
175            .build()
176    }
177}
178
179impl From<QueryContext> for api::v1::QueryContext {
180    fn from(
181        QueryContext {
182            current_catalog,
183            mutable_session_data: mutable_inner,
184            extensions,
185            channel,
186            snapshot_seqs,
187            sst_min_sequences,
188            mutable_query_context_data,
189            ..
190        }: QueryContext,
191    ) -> Self {
192        let explain = mutable_query_context_data.read().unwrap().explain_options;
193        let mutable_inner = mutable_inner.read().unwrap();
194        api::v1::QueryContext {
195            current_catalog,
196            current_schema: mutable_inner.schema.clone(),
197            timezone: mutable_inner.timezone.to_string(),
198            extensions,
199            channel: channel as u32,
200            snapshot_seqs: Some(api::v1::SnapshotSequences {
201                snapshot_seqs: snapshot_seqs.read().unwrap().clone(),
202                sst_min_sequences: sst_min_sequences.read().unwrap().clone(),
203            }),
204            explain,
205        }
206    }
207}
208
209impl From<&QueryContext> for api::v1::QueryContext {
210    fn from(ctx: &QueryContext) -> Self {
211        ctx.clone().into()
212    }
213}
214
215impl QueryContext {
216    pub fn arc() -> QueryContextRef {
217        Arc::new(QueryContextBuilder::default().build())
218    }
219
220    pub fn with(catalog: &str, schema: &str) -> QueryContext {
221        QueryContextBuilder::default()
222            .current_catalog(catalog.to_string())
223            .current_schema(schema.to_string())
224            .build()
225    }
226
227    pub fn with_channel(catalog: &str, schema: &str, channel: Channel) -> QueryContext {
228        QueryContextBuilder::default()
229            .current_catalog(catalog.to_string())
230            .current_schema(schema.to_string())
231            .channel(channel)
232            .build()
233    }
234
235    pub fn with_db_name(db_name: Option<&str>) -> QueryContext {
236        let (catalog, schema) = db_name
237            .map(|db| {
238                let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
239                (catalog, schema)
240            })
241            .unwrap_or_else(|| {
242                (
243                    DEFAULT_CATALOG_NAME.to_string(),
244                    DEFAULT_SCHEMA_NAME.to_string(),
245                )
246            });
247        QueryContextBuilder::default()
248            .current_catalog(catalog)
249            .current_schema(schema.to_string())
250            .build()
251    }
252
253    pub fn current_schema(&self) -> String {
254        self.mutable_session_data.read().unwrap().schema.clone()
255    }
256
257    pub fn set_current_schema(&self, new_schema: &str) {
258        self.mutable_session_data.write().unwrap().schema = new_schema.to_string();
259    }
260
261    pub fn current_catalog(&self) -> &str {
262        &self.current_catalog
263    }
264
265    pub fn set_current_catalog(&mut self, new_catalog: &str) {
266        self.current_catalog = new_catalog.to_string();
267    }
268
269    pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) {
270        &*self.sql_dialect
271    }
272
273    pub fn get_db_string(&self) -> String {
274        let catalog = self.current_catalog();
275        let schema = self.current_schema();
276        build_db_string(catalog, &schema)
277    }
278
279    pub fn timezone(&self) -> Timezone {
280        self.mutable_session_data.read().unwrap().timezone.clone()
281    }
282
283    pub fn set_timezone(&self, timezone: Timezone) {
284        self.mutable_session_data.write().unwrap().timezone = timezone;
285    }
286
287    pub fn read_preference(&self) -> ReadPreference {
288        self.mutable_session_data.read().unwrap().read_preference
289    }
290
291    pub fn set_read_preference(&self, read_preference: ReadPreference) {
292        self.mutable_session_data.write().unwrap().read_preference = read_preference;
293    }
294
295    pub fn current_user(&self) -> UserInfoRef {
296        self.mutable_session_data.read().unwrap().user_info.clone()
297    }
298
299    pub fn set_current_user(&self, user: UserInfoRef) {
300        self.mutable_session_data.write().unwrap().user_info = user;
301    }
302
303    pub fn set_extension<S1: Into<String>, S2: Into<String>>(&mut self, key: S1, value: S2) {
304        self.extensions.insert(key.into(), value.into());
305    }
306
307    pub fn extension<S: AsRef<str>>(&self, key: S) -> Option<&str> {
308        self.extensions.get(key.as_ref()).map(|v| v.as_str())
309    }
310
311    pub fn extensions(&self) -> HashMap<String, String> {
312        self.extensions.clone()
313    }
314
315    /// Default to double quote and fallback to back quote
316    pub fn quote_style(&self) -> char {
317        if self.sql_dialect().is_delimited_identifier_start('"') {
318            '"'
319        } else if self.sql_dialect().is_delimited_identifier_start('\'') {
320            '\''
321        } else {
322            '`'
323        }
324    }
325
326    pub fn configuration_parameter(&self) -> &ConfigurationVariables {
327        &self.configuration_parameter
328    }
329
330    pub fn channel(&self) -> Channel {
331        self.channel
332    }
333
334    pub fn set_channel(&mut self, channel: Channel) {
335        self.channel = channel;
336    }
337
338    pub fn warning(&self) -> Option<String> {
339        self.mutable_query_context_data
340            .read()
341            .unwrap()
342            .warning
343            .clone()
344    }
345
346    pub fn set_warning(&self, msg: String) {
347        self.mutable_query_context_data.write().unwrap().warning = Some(msg);
348    }
349
350    pub fn explain_format(&self) -> Option<String> {
351        self.mutable_query_context_data
352            .read()
353            .unwrap()
354            .explain_format
355            .clone()
356    }
357
358    pub fn set_explain_format(&self, format: String) {
359        self.mutable_query_context_data
360            .write()
361            .unwrap()
362            .explain_format = Some(format);
363    }
364
365    pub fn explain_verbose(&self) -> bool {
366        self.mutable_query_context_data
367            .read()
368            .unwrap()
369            .explain_options
370            .map(|opts| opts.verbose)
371            .unwrap_or(false)
372    }
373
374    pub fn set_explain_verbose(&self, verbose: bool) {
375        self.mutable_query_context_data
376            .write()
377            .unwrap()
378            .explain_options
379            .get_or_insert_default()
380            .verbose = verbose;
381    }
382
383    pub fn query_timeout(&self) -> Option<Duration> {
384        self.mutable_session_data.read().unwrap().query_timeout
385    }
386
387    pub fn query_timeout_as_millis(&self) -> u128 {
388        let timeout = self.mutable_session_data.read().unwrap().query_timeout;
389        if let Some(t) = timeout {
390            return t.as_millis();
391        }
392        0
393    }
394
395    pub fn set_query_timeout(&self, timeout: Duration) {
396        self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
397    }
398
399    pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
400        let mut guard = self.mutable_session_data.write().unwrap();
401        guard.cursors.insert(name, Arc::new(rb));
402
403        let cursor_count = guard.cursors.len();
404        if cursor_count > CURSOR_COUNT_WARNING_LIMIT {
405            warn!("Current connection has {} open cursors", cursor_count);
406        }
407    }
408
409    pub fn remove_cursor(&self, name: &str) {
410        let mut guard = self.mutable_session_data.write().unwrap();
411        guard.cursors.remove(name);
412    }
413
414    pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
415        let guard = self.mutable_session_data.read().unwrap();
416        let rb = guard.cursors.get(name);
417        rb.cloned()
418    }
419
420    pub fn snapshots(&self) -> HashMap<u64, u64> {
421        self.snapshot_seqs.read().unwrap().clone()
422    }
423
424    pub fn get_snapshot(&self, region_id: u64) -> Option<u64> {
425        self.snapshot_seqs.read().unwrap().get(&region_id).cloned()
426    }
427
428    /// Returns `true` if the session can cast strings to numbers in MySQL style.
429    pub fn auto_string_to_numeric(&self) -> bool {
430        matches!(self.channel, Channel::Mysql)
431    }
432
433    /// Finds the minimal sequence of SST files to scan of a Region.
434    pub fn sst_min_sequence(&self, region_id: u64) -> Option<u64> {
435        self.sst_min_sequences
436            .read()
437            .unwrap()
438            .get(&region_id)
439            .copied()
440    }
441
442    pub fn process_id(&self) -> u32 {
443        self.process_id
444    }
445
446    /// Get client information
447    pub fn conn_info(&self) -> &ConnInfo {
448        &self.conn_info
449    }
450}
451
452impl QueryContextBuilder {
453    pub fn build(self) -> QueryContext {
454        let channel = self.channel.unwrap_or_default();
455        QueryContext {
456            current_catalog: self
457                .current_catalog
458                .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()),
459            snapshot_seqs: self.snapshot_seqs.unwrap_or_default(),
460            sst_min_sequences: self.sst_min_sequences.unwrap_or_default(),
461            mutable_session_data: self.mutable_session_data.unwrap_or_default(),
462            mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(),
463            sql_dialect: self
464                .sql_dialect
465                .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
466            extensions: self.extensions.unwrap_or_default(),
467            configuration_parameter: self
468                .configuration_parameter
469                .unwrap_or_else(|| Arc::new(ConfigurationVariables::default())),
470            channel,
471            process_id: self.process_id.unwrap_or_default(),
472            conn_info: self.conn_info.unwrap_or_default(),
473        }
474    }
475
476    pub fn set_extension(mut self, key: String, value: String) -> Self {
477        self.extensions
478            .get_or_insert_with(HashMap::new)
479            .insert(key, value);
480        self
481    }
482}
483
484#[derive(Debug, Clone, Default)]
485pub struct ConnInfo {
486    pub client_addr: Option<SocketAddr>,
487    pub channel: Channel,
488}
489
490impl Display for ConnInfo {
491    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
492        write!(
493            f,
494            "{}[{}]",
495            self.channel,
496            self.client_addr
497                .map(|addr| addr.to_string())
498                .as_deref()
499                .unwrap_or("unknown client addr")
500        )
501    }
502}
503
504impl ConnInfo {
505    pub fn new(client_addr: Option<SocketAddr>, channel: Channel) -> Self {
506        Self {
507            client_addr,
508            channel,
509        }
510    }
511}
512
/// Protocol through which a query reached the server.
///
/// The explicit discriminants are the numeric codes used when a channel is
/// serialized as an integer.
#[derive(Debug, PartialEq, Default, Clone, Copy)]
#[repr(u8)]
pub enum Channel {
    /// Origin not known; also the default.
    #[default]
    Unknown = 0,

    /// MySQL wire protocol.
    Mysql = 1,
    /// PostgreSQL wire protocol.
    Postgres = 2,
    /// HTTP API.
    Http = 3,
    /// Prometheus protocol.
    Prometheus = 4,
    /// OpenTelemetry protocol.
    Otlp = 5,
    /// gRPC API.
    Grpc = 6,
    /// InfluxDB protocol.
    Influx = 7,
    /// OpenTSDB protocol.
    Opentsdb = 8,
    /// Loki protocol.
    Loki = 9,
    /// Elasticsearch protocol.
    Elasticsearch = 10,
    /// Jaeger protocol.
    Jaeger = 11,
}
531
532impl From<u32> for Channel {
533    fn from(value: u32) -> Self {
534        match value {
535            1 => Self::Mysql,
536            2 => Self::Postgres,
537            3 => Self::Http,
538            4 => Self::Prometheus,
539            5 => Self::Otlp,
540            6 => Self::Grpc,
541            7 => Self::Influx,
542            8 => Self::Opentsdb,
543            9 => Self::Loki,
544            10 => Self::Elasticsearch,
545            11 => Self::Jaeger,
546            _ => Self::Unknown,
547        }
548    }
549}
550
551impl Channel {
552    pub fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
553        match self {
554            Channel::Mysql => Arc::new(MySqlDialect {}),
555            Channel::Postgres => Arc::new(PostgreSqlDialect {}),
556            _ => Arc::new(GenericDialect {}),
557        }
558    }
559}
560
561impl Display for Channel {
562    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
563        match self {
564            Channel::Mysql => write!(f, "mysql"),
565            Channel::Postgres => write!(f, "postgres"),
566            Channel::Http => write!(f, "http"),
567            Channel::Prometheus => write!(f, "prometheus"),
568            Channel::Otlp => write!(f, "otlp"),
569            Channel::Grpc => write!(f, "grpc"),
570            Channel::Influx => write!(f, "influx"),
571            Channel::Opentsdb => write!(f, "opentsdb"),
572            Channel::Loki => write!(f, "loki"),
573            Channel::Elasticsearch => write!(f, "elasticsearch"),
574            Channel::Jaeger => write!(f, "jaeger"),
575            Channel::Unknown => write!(f, "unknown"),
576        }
577    }
578}
579
580#[derive(Default, Debug)]
581pub struct ConfigurationVariables {
582    postgres_bytea_output: ArcSwap<PGByteaOutputValue>,
583    pg_datestyle_format: ArcSwap<(PGDateTimeStyle, PGDateOrder)>,
584}
585
586impl Clone for ConfigurationVariables {
587    fn clone(&self) -> Self {
588        Self {
589            postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()),
590            pg_datestyle_format: ArcSwap::new(self.pg_datestyle_format.load().clone()),
591        }
592    }
593}
594
595impl ConfigurationVariables {
596    pub fn new() -> Self {
597        Self::default()
598    }
599
600    pub fn set_postgres_bytea_output(&self, value: PGByteaOutputValue) {
601        let _ = self.postgres_bytea_output.swap(Arc::new(value));
602    }
603
604    pub fn postgres_bytea_output(&self) -> Arc<PGByteaOutputValue> {
605        self.postgres_bytea_output.load().clone()
606    }
607
608    pub fn pg_datetime_style(&self) -> Arc<(PGDateTimeStyle, PGDateOrder)> {
609        self.pg_datestyle_format.load().clone()
610    }
611
612    pub fn set_pg_datetime_style(&self, style: PGDateTimeStyle, order: PGDateOrder) {
613        self.pg_datestyle_format.swap(Arc::new((style, order)));
614    }
615}
616
#[cfg(test)]
mod test {
    use common_catalog::consts::DEFAULT_CATALOG_NAME;

    use super::*;
    use crate::context::Channel;
    use crate::Session;

    #[test]
    fn test_session() {
        let session = Session::new(
            Some("127.0.0.1:9000".parse().unwrap()),
            Channel::Mysql,
            Default::default(),
            100,
        );

        // A fresh session carries the default user.
        assert_eq!(session.user_info().username(), "greptime");

        // Connection info reflects the channel and address given at construction.
        let conn_info = session.conn_info();
        assert_eq!(conn_info.channel, Channel::Mysql);
        let addr = conn_info
            .client_addr
            .expect("client address was provided to Session::new");
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
        assert_eq!(addr.port(), 9000);
        assert_eq!(conn_info.to_string(), "mysql[127.0.0.1:9000]");

        assert_eq!(session.process_id(), 100);
    }

    #[test]
    fn test_context_db_string() {
        let custom_catalog = QueryContext::with("a0b1c2d3", "test");
        assert_eq!(custom_catalog.get_db_string(), "a0b1c2d3-test");

        // The default catalog is omitted from the db string.
        let default_catalog = QueryContext::with(DEFAULT_CATALOG_NAME, "test");
        assert_eq!(default_catalog.get_db_string(), "test");
    }
}