1use std::collections::HashMap;
16use std::fmt::{Display, Formatter};
17use std::net::SocketAddr;
18use std::sync::{Arc, RwLock};
19use std::time::Duration;
20
21use api::v1::region::RegionRequestHeader;
22use api::v1::ExplainOptions;
23use arc_swap::ArcSwap;
24use auth::UserInfoRef;
25use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
26use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
27use common_recordbatch::cursor::RecordBatchStreamCursor;
28use common_telemetry::warn;
29use common_time::timezone::parse_timezone;
30use common_time::Timezone;
31use derive_builder::Builder;
32use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
33
34use crate::protocol_ctx::ProtocolCtx;
35use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
36use crate::{MutableInner, ReadPreference};
37
/// Shared, reference-counted handle to a [`QueryContext`].
pub type QueryContextRef = Arc<QueryContext>;
/// Shared, reference-counted handle to a [`ConnInfo`].
pub type ConnInfoRef = Arc<ConnInfo>;

/// Number of open cursors above which `QueryContext::insert_cursor` logs a
/// warning. Exceeding it only produces a log line; the insert still happens.
const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
42
/// Per-query context: the catalog/schema a query runs against together with
/// session state, SQL dialect, channel and connection details.
///
/// A `QueryContextBuilder` is derived with the owned pattern. The generated
/// `build` function is skipped (`build_fn(skip)`), so `build` is written by
/// hand on the builder and supplies the defaults for unset fields.
///
/// Fields wrapped in `Arc<RwLock<..>>` are shared between clones of the
/// context, because cloning only clones the `Arc`.
#[derive(Debug, Builder, Clone)]
#[builder(pattern = "owned")]
#[builder(build_fn(skip))]
pub struct QueryContext {
    /// Name of the catalog this context currently points at.
    current_catalog: String,
    /// Shared `u64 -> u64` map, looked up by region id in `get_snapshot`.
    // NOTE(review): values are presumably snapshot sequence numbers — confirm.
    snapshot_seqs: Arc<RwLock<HashMap<u64, u64>>>,
    /// Shared `u64 -> u64` map, looked up by region id in `sst_min_sequence`.
    // NOTE(review): values are presumably minimal SST sequence numbers — confirm.
    sst_min_sequences: Arc<RwLock<HashMap<u64, u64>>>,
    /// Lock-protected session state. The accessors on this type read and
    /// write its schema, timezone, read preference, user info, query timeout
    /// and cursors.
    #[builder(default)]
    mutable_session_data: Arc<RwLock<MutableInner>>,
    /// Lock-protected per-query state (warning, explain format/options).
    #[builder(default)]
    mutable_query_context_data: Arc<RwLock<QueryContextMutableFields>>,
    /// SQL dialect returned by `sql_dialect()` and consulted by `quote_style()`.
    sql_dialect: Arc<dyn Dialect + Send + Sync>,
    /// Free-form string key/value extensions attached to the context.
    #[builder(default)]
    extensions: HashMap<String, String>,
    /// Configuration variables exposed through `configuration_parameter()`.
    #[builder(default)]
    configuration_parameter: Arc<ConfigurationVariables>,
    /// Channel the query arrived through.
    #[builder(default)]
    channel: Channel,
    /// Process id exposed through `process_id()`.
    // NOTE(review): presumably identifies the running query/connection — confirm.
    #[builder(default)]
    process_id: u32,
    /// Connection details (client address and channel).
    #[builder(default)]
    conn_info: ConnInfo,
    /// Protocol-specific context exposed through `protocol_ctx()`.
    #[builder(default)]
    protocol_ctx: ProtocolCtx,
}
78
/// Mutable per-query state that `QueryContext` keeps behind an
/// `Arc<RwLock<..>>` so it can be updated through a shared reference.
#[derive(Debug, Builder, Clone, Default)]
pub struct QueryContextMutableFields {
    /// Warning message stored by `QueryContext::set_warning`, if any.
    warning: Option<String>,
    /// Explain format stored by `QueryContext::set_explain_format`, if any.
    explain_format: Option<String>,
    /// Explain options; `QueryContext::explain_verbose` reads the `verbose`
    /// flag and reports `false` when this is `None`.
    explain_options: Option<ExplainOptions>,
}
88
89impl Display for QueryContext {
90 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
91 write!(
92 f,
93 "QueryContext{{catalog: {}, schema: {}}}",
94 self.current_catalog(),
95 self.current_schema()
96 )
97 }
98}
99
100impl QueryContextBuilder {
101 pub fn current_schema(mut self, schema: String) -> Self {
102 if self.mutable_session_data.is_none() {
103 self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
104 }
105
106 self.mutable_session_data
108 .as_mut()
109 .unwrap()
110 .write()
111 .unwrap()
112 .schema = schema;
113 self
114 }
115
116 pub fn timezone(mut self, timezone: Timezone) -> Self {
117 if self.mutable_session_data.is_none() {
118 self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
119 }
120
121 self.mutable_session_data
122 .as_mut()
123 .unwrap()
124 .write()
125 .unwrap()
126 .timezone = timezone;
127 self
128 }
129
130 pub fn explain_options(mut self, explain_options: Option<ExplainOptions>) -> Self {
131 self.mutable_query_context_data
132 .get_or_insert_default()
133 .write()
134 .unwrap()
135 .explain_options = explain_options;
136 self
137 }
138
139 pub fn read_preference(mut self, read_preference: ReadPreference) -> Self {
140 self.mutable_session_data
141 .get_or_insert_default()
142 .write()
143 .unwrap()
144 .read_preference = read_preference;
145 self
146 }
147}
148
149impl From<&RegionRequestHeader> for QueryContext {
150 fn from(value: &RegionRequestHeader) -> Self {
151 if let Some(ctx) = &value.query_context {
152 ctx.clone().into()
153 } else {
154 QueryContextBuilder::default().build()
155 }
156 }
157}
158
159impl From<api::v1::QueryContext> for QueryContext {
160 fn from(ctx: api::v1::QueryContext) -> Self {
161 let sequences = ctx.snapshot_seqs.as_ref();
162 QueryContextBuilder::default()
163 .current_catalog(ctx.current_catalog)
164 .current_schema(ctx.current_schema)
165 .timezone(parse_timezone(Some(&ctx.timezone)))
166 .extensions(ctx.extensions)
167 .channel(ctx.channel.into())
168 .snapshot_seqs(Arc::new(RwLock::new(
169 sequences
170 .map(|x| x.snapshot_seqs.clone())
171 .unwrap_or_default(),
172 )))
173 .sst_min_sequences(Arc::new(RwLock::new(
174 sequences
175 .map(|x| x.sst_min_sequences.clone())
176 .unwrap_or_default(),
177 )))
178 .explain_options(ctx.explain)
179 .build()
180 }
181}
182
183impl From<QueryContext> for api::v1::QueryContext {
184 fn from(
185 QueryContext {
186 current_catalog,
187 mutable_session_data: mutable_inner,
188 extensions,
189 channel,
190 snapshot_seqs,
191 sst_min_sequences,
192 mutable_query_context_data,
193 ..
194 }: QueryContext,
195 ) -> Self {
196 let explain = mutable_query_context_data.read().unwrap().explain_options;
197 let mutable_inner = mutable_inner.read().unwrap();
198 api::v1::QueryContext {
199 current_catalog,
200 current_schema: mutable_inner.schema.clone(),
201 timezone: mutable_inner.timezone.to_string(),
202 extensions,
203 channel: channel as u32,
204 snapshot_seqs: Some(api::v1::SnapshotSequences {
205 snapshot_seqs: snapshot_seqs.read().unwrap().clone(),
206 sst_min_sequences: sst_min_sequences.read().unwrap().clone(),
207 }),
208 explain,
209 }
210 }
211}
212
213impl From<&QueryContext> for api::v1::QueryContext {
214 fn from(ctx: &QueryContext) -> Self {
215 ctx.clone().into()
216 }
217}
218
219impl QueryContext {
220 pub fn arc() -> QueryContextRef {
221 Arc::new(QueryContextBuilder::default().build())
222 }
223
224 pub fn with(catalog: &str, schema: &str) -> QueryContext {
225 QueryContextBuilder::default()
226 .current_catalog(catalog.to_string())
227 .current_schema(schema.to_string())
228 .build()
229 }
230
231 pub fn with_channel(catalog: &str, schema: &str, channel: Channel) -> QueryContext {
232 QueryContextBuilder::default()
233 .current_catalog(catalog.to_string())
234 .current_schema(schema.to_string())
235 .channel(channel)
236 .build()
237 }
238
239 pub fn with_db_name(db_name: Option<&str>) -> QueryContext {
240 let (catalog, schema) = db_name
241 .map(|db| {
242 let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
243 (catalog, schema)
244 })
245 .unwrap_or_else(|| {
246 (
247 DEFAULT_CATALOG_NAME.to_string(),
248 DEFAULT_SCHEMA_NAME.to_string(),
249 )
250 });
251 QueryContextBuilder::default()
252 .current_catalog(catalog)
253 .current_schema(schema.to_string())
254 .build()
255 }
256
257 pub fn current_schema(&self) -> String {
258 self.mutable_session_data.read().unwrap().schema.clone()
259 }
260
261 pub fn set_current_schema(&self, new_schema: &str) {
262 self.mutable_session_data.write().unwrap().schema = new_schema.to_string();
263 }
264
265 pub fn current_catalog(&self) -> &str {
266 &self.current_catalog
267 }
268
269 pub fn set_current_catalog(&mut self, new_catalog: &str) {
270 self.current_catalog = new_catalog.to_string();
271 }
272
273 pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) {
274 &*self.sql_dialect
275 }
276
277 pub fn get_db_string(&self) -> String {
278 let catalog = self.current_catalog();
279 let schema = self.current_schema();
280 build_db_string(catalog, &schema)
281 }
282
283 pub fn timezone(&self) -> Timezone {
284 self.mutable_session_data.read().unwrap().timezone.clone()
285 }
286
287 pub fn set_timezone(&self, timezone: Timezone) {
288 self.mutable_session_data.write().unwrap().timezone = timezone;
289 }
290
291 pub fn read_preference(&self) -> ReadPreference {
292 self.mutable_session_data.read().unwrap().read_preference
293 }
294
295 pub fn set_read_preference(&self, read_preference: ReadPreference) {
296 self.mutable_session_data.write().unwrap().read_preference = read_preference;
297 }
298
299 pub fn current_user(&self) -> UserInfoRef {
300 self.mutable_session_data.read().unwrap().user_info.clone()
301 }
302
303 pub fn set_current_user(&self, user: UserInfoRef) {
304 self.mutable_session_data.write().unwrap().user_info = user;
305 }
306
307 pub fn set_extension<S1: Into<String>, S2: Into<String>>(&mut self, key: S1, value: S2) {
308 self.extensions.insert(key.into(), value.into());
309 }
310
311 pub fn extension<S: AsRef<str>>(&self, key: S) -> Option<&str> {
312 self.extensions.get(key.as_ref()).map(|v| v.as_str())
313 }
314
315 pub fn extensions(&self) -> HashMap<String, String> {
316 self.extensions.clone()
317 }
318
319 pub fn quote_style(&self) -> char {
321 if self.sql_dialect().is_delimited_identifier_start('"') {
322 '"'
323 } else if self.sql_dialect().is_delimited_identifier_start('\'') {
324 '\''
325 } else {
326 '`'
327 }
328 }
329
330 pub fn configuration_parameter(&self) -> &ConfigurationVariables {
331 &self.configuration_parameter
332 }
333
334 pub fn channel(&self) -> Channel {
335 self.channel
336 }
337
338 pub fn set_channel(&mut self, channel: Channel) {
339 self.channel = channel;
340 }
341
342 pub fn warning(&self) -> Option<String> {
343 self.mutable_query_context_data
344 .read()
345 .unwrap()
346 .warning
347 .clone()
348 }
349
350 pub fn set_warning(&self, msg: String) {
351 self.mutable_query_context_data.write().unwrap().warning = Some(msg);
352 }
353
354 pub fn explain_format(&self) -> Option<String> {
355 self.mutable_query_context_data
356 .read()
357 .unwrap()
358 .explain_format
359 .clone()
360 }
361
362 pub fn set_explain_format(&self, format: String) {
363 self.mutable_query_context_data
364 .write()
365 .unwrap()
366 .explain_format = Some(format);
367 }
368
369 pub fn explain_verbose(&self) -> bool {
370 self.mutable_query_context_data
371 .read()
372 .unwrap()
373 .explain_options
374 .map(|opts| opts.verbose)
375 .unwrap_or(false)
376 }
377
378 pub fn set_explain_verbose(&self, verbose: bool) {
379 self.mutable_query_context_data
380 .write()
381 .unwrap()
382 .explain_options
383 .get_or_insert_default()
384 .verbose = verbose;
385 }
386
387 pub fn query_timeout(&self) -> Option<Duration> {
388 self.mutable_session_data.read().unwrap().query_timeout
389 }
390
391 pub fn query_timeout_as_millis(&self) -> u128 {
392 let timeout = self.mutable_session_data.read().unwrap().query_timeout;
393 if let Some(t) = timeout {
394 return t.as_millis();
395 }
396 0
397 }
398
399 pub fn set_query_timeout(&self, timeout: Duration) {
400 self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
401 }
402
403 pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
404 let mut guard = self.mutable_session_data.write().unwrap();
405 guard.cursors.insert(name, Arc::new(rb));
406
407 let cursor_count = guard.cursors.len();
408 if cursor_count > CURSOR_COUNT_WARNING_LIMIT {
409 warn!("Current connection has {} open cursors", cursor_count);
410 }
411 }
412
413 pub fn remove_cursor(&self, name: &str) {
414 let mut guard = self.mutable_session_data.write().unwrap();
415 guard.cursors.remove(name);
416 }
417
418 pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
419 let guard = self.mutable_session_data.read().unwrap();
420 let rb = guard.cursors.get(name);
421 rb.cloned()
422 }
423
424 pub fn snapshots(&self) -> HashMap<u64, u64> {
425 self.snapshot_seqs.read().unwrap().clone()
426 }
427
428 pub fn get_snapshot(&self, region_id: u64) -> Option<u64> {
429 self.snapshot_seqs.read().unwrap().get(®ion_id).cloned()
430 }
431
432 pub fn auto_string_to_numeric(&self) -> bool {
434 matches!(self.channel, Channel::Mysql)
435 }
436
437 pub fn sst_min_sequence(&self, region_id: u64) -> Option<u64> {
439 self.sst_min_sequences
440 .read()
441 .unwrap()
442 .get(®ion_id)
443 .copied()
444 }
445
446 pub fn process_id(&self) -> u32 {
447 self.process_id
448 }
449
450 pub fn conn_info(&self) -> &ConnInfo {
452 &self.conn_info
453 }
454
455 pub fn protocol_ctx(&self) -> &ProtocolCtx {
456 &self.protocol_ctx
457 }
458
459 pub fn set_protocol_ctx(&mut self, protocol_ctx: ProtocolCtx) {
460 self.protocol_ctx = protocol_ctx;
461 }
462}
463
464impl QueryContextBuilder {
465 pub fn build(self) -> QueryContext {
466 let channel = self.channel.unwrap_or_default();
467 QueryContext {
468 current_catalog: self
469 .current_catalog
470 .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()),
471 snapshot_seqs: self.snapshot_seqs.unwrap_or_default(),
472 sst_min_sequences: self.sst_min_sequences.unwrap_or_default(),
473 mutable_session_data: self.mutable_session_data.unwrap_or_default(),
474 mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(),
475 sql_dialect: self
476 .sql_dialect
477 .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
478 extensions: self.extensions.unwrap_or_default(),
479 configuration_parameter: self
480 .configuration_parameter
481 .unwrap_or_else(|| Arc::new(ConfigurationVariables::default())),
482 channel,
483 process_id: self.process_id.unwrap_or_default(),
484 conn_info: self.conn_info.unwrap_or_default(),
485 protocol_ctx: self.protocol_ctx.unwrap_or_default(),
486 }
487 }
488
489 pub fn set_extension(mut self, key: String, value: String) -> Self {
490 self.extensions
491 .get_or_insert_with(HashMap::new)
492 .insert(key, value);
493 self
494 }
495}
496
/// Connection details attached to a query context.
#[derive(Debug, Clone, Default)]
pub struct ConnInfo {
    /// Socket address of the client, when known. `Display` prints
    /// `unknown client addr` when this is `None`.
    pub client_addr: Option<SocketAddr>,
    /// Channel the connection came in through.
    pub channel: Channel,
}
502
503impl Display for ConnInfo {
504 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
505 write!(
506 f,
507 "{}[{}]",
508 self.channel,
509 self.client_addr
510 .map(|addr| addr.to_string())
511 .as_deref()
512 .unwrap_or("unknown client addr")
513 )
514 }
515}
516
517impl ConnInfo {
518 pub fn new(client_addr: Option<SocketAddr>, channel: Channel) -> Self {
519 Self {
520 client_addr,
521 channel,
522 }
523 }
524}
525
/// Identifies the channel a request arrived through.
///
/// The explicit discriminants are the numeric codes understood by
/// `From<u32>`; any code without a matching variant converts to `Unknown`.
/// `AsRef<str>` / `Display` yield the lowercase variant name.
#[derive(Debug, PartialEq, Default, Clone, Copy)]
#[repr(u8)]
pub enum Channel {
    /// Default variant, also the fallback for unrecognised numeric codes.
    #[default]
    Unknown = 0,

    Mysql = 1,
    Postgres = 2,
    HttpSql = 3,
    Prometheus = 4,
    Otlp = 5,
    Grpc = 6,
    Influx = 7,
    Opentsdb = 8,
    Loki = 9,
    Elasticsearch = 10,
    Jaeger = 11,
    Log = 12,
    Promql = 13,
}
546
547impl From<u32> for Channel {
548 fn from(value: u32) -> Self {
549 match value {
550 1 => Self::Mysql,
551 2 => Self::Postgres,
552 3 => Self::HttpSql,
553 4 => Self::Prometheus,
554 5 => Self::Otlp,
555 6 => Self::Grpc,
556 7 => Self::Influx,
557 8 => Self::Opentsdb,
558 9 => Self::Loki,
559 10 => Self::Elasticsearch,
560 11 => Self::Jaeger,
561 12 => Self::Log,
562 13 => Self::Promql,
563 _ => Self::Unknown,
564 }
565 }
566}
567
568impl Channel {
569 pub fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
570 match self {
571 Channel::Mysql => Arc::new(MySqlDialect {}),
572 Channel::Postgres => Arc::new(PostgreSqlDialect {}),
573 _ => Arc::new(GenericDialect {}),
574 }
575 }
576}
577
578impl Display for Channel {
579 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
580 write!(f, "{}", self.as_ref())
581 }
582}
583
584impl AsRef<str> for Channel {
585 fn as_ref(&self) -> &str {
586 match self {
587 Channel::Mysql => "mysql",
588 Channel::Postgres => "postgres",
589 Channel::HttpSql => "httpsql",
590 Channel::Prometheus => "prometheus",
591 Channel::Otlp => "otlp",
592 Channel::Grpc => "grpc",
593 Channel::Influx => "influx",
594 Channel::Opentsdb => "opentsdb",
595 Channel::Loki => "loki",
596 Channel::Elasticsearch => "elasticsearch",
597 Channel::Jaeger => "jaeger",
598 Channel::Log => "log",
599 Channel::Promql => "promql",
600 Channel::Unknown => "unknown",
601 }
602 }
603}
604
/// Configuration values attached to a query context.
///
/// Each value sits in an `ArcSwap`, which is what lets the setters replace
/// it through `&self`.
#[derive(Default, Debug)]
pub struct ConfigurationVariables {
    /// Value read/written by `postgres_bytea_output` / `set_postgres_bytea_output`.
    // NOTE(review): presumably mirrors PostgreSQL's `bytea_output` setting — confirm.
    postgres_bytea_output: ArcSwap<PGByteaOutputValue>,
    /// Date/time style and date order pair, read/written by
    /// `pg_datetime_style` / `set_pg_datetime_style`.
    pg_datestyle_format: ArcSwap<(PGDateTimeStyle, PGDateOrder)>,
    /// Flag read/written by `allow_query_fallback` / `set_allow_query_fallback`.
    allow_query_fallback: ArcSwap<bool>,
}
611
612impl Clone for ConfigurationVariables {
613 fn clone(&self) -> Self {
614 Self {
615 postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()),
616 pg_datestyle_format: ArcSwap::new(self.pg_datestyle_format.load().clone()),
617 allow_query_fallback: ArcSwap::new(self.allow_query_fallback.load().clone()),
618 }
619 }
620}
621
622impl ConfigurationVariables {
623 pub fn new() -> Self {
624 Self::default()
625 }
626
627 pub fn set_postgres_bytea_output(&self, value: PGByteaOutputValue) {
628 let _ = self.postgres_bytea_output.swap(Arc::new(value));
629 }
630
631 pub fn postgres_bytea_output(&self) -> Arc<PGByteaOutputValue> {
632 self.postgres_bytea_output.load().clone()
633 }
634
635 pub fn pg_datetime_style(&self) -> Arc<(PGDateTimeStyle, PGDateOrder)> {
636 self.pg_datestyle_format.load().clone()
637 }
638
639 pub fn set_pg_datetime_style(&self, style: PGDateTimeStyle, order: PGDateOrder) {
640 self.pg_datestyle_format.swap(Arc::new((style, order)));
641 }
642
643 pub fn allow_query_fallback(&self) -> bool {
644 **self.allow_query_fallback.load()
645 }
646
647 pub fn set_allow_query_fallback(&self, allow: bool) {
648 self.allow_query_fallback.swap(Arc::new(allow));
649 }
650}
651
#[cfg(test)]
mod test {
    use common_catalog::consts::DEFAULT_CATALOG_NAME;

    use super::*;
    use crate::context::Channel;
    use crate::Session;

    /// A freshly created session exposes its default user, connection info
    /// and process id.
    #[test]
    fn test_session() {
        let addr = "127.0.0.1:9000".parse().unwrap();
        let session = Session::new(Some(addr), Channel::Mysql, Default::default(), 100);
        assert_eq!(session.user_info().username(), "greptime");

        let conn_info = session.conn_info();
        assert_eq!(conn_info.channel, Channel::Mysql);

        let client_addr = conn_info.client_addr.as_ref().unwrap();
        assert_eq!(client_addr.ip().to_string(), "127.0.0.1");
        assert_eq!(client_addr.port(), 9000);

        assert_eq!("mysql[127.0.0.1:9000]", conn_info.to_string());
        assert_eq!(100, session.process_id());
    }

    /// The db string carries the catalog prefix only for non-default catalogs.
    #[test]
    fn test_context_db_string() {
        let cases = [
            ("a0b1c2d3", "test", "a0b1c2d3-test"),
            (DEFAULT_CATALOG_NAME, "test", "test"),
        ];
        for (catalog, schema, expected) in cases {
            let context = QueryContext::with(catalog, schema);
            assert_eq!(expected, context.get_db_string());
        }
    }
}