1use std::collections::HashMap;
16use std::fmt::{Display, Formatter};
17use std::net::SocketAddr;
18use std::sync::{Arc, RwLock};
19use std::time::Duration;
20
21use api::v1::region::RegionRequestHeader;
22use api::v1::ExplainOptions;
23use arc_swap::ArcSwap;
24use auth::UserInfoRef;
25use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
26use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
27use common_recordbatch::cursor::RecordBatchStreamCursor;
28use common_telemetry::warn;
29use common_time::timezone::parse_timezone;
30use common_time::Timezone;
31use derive_builder::Builder;
32use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
33
34use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
35use crate::{MutableInner, ReadPreference};
36
/// Shared, reference-counted handle to a [`QueryContext`].
pub type QueryContextRef = Arc<QueryContext>;
/// Shared, reference-counted handle to a [`ConnInfo`].
pub type ConnInfoRef = Arc<ConnInfo>;

/// Number of open cursors above which `QueryContext::insert_cursor` logs a warning.
const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
41
/// Context carried alongside a query: current catalog/schema, session state,
/// SQL dialect, extensions and the protocol channel it arrived through.
///
/// Fields stored as `Arc<RwLock<..>>` are shared by every clone of the
/// context, so updates made through `&self` setters (e.g. `set_current_schema`)
/// are visible through all clones.
#[derive(Debug, Builder, Clone)]
#[builder(pattern = "owned")]
#[builder(build_fn(skip))]
pub struct QueryContext {
    /// Catalog the query runs against; `build` defaults it to `DEFAULT_CATALOG_NAME`.
    current_catalog: String,
    /// Snapshot sequence per region id (looked up by `get_snapshot`).
    snapshot_seqs: Arc<RwLock<HashMap<u64, u64>>>,
    /// Minimum SST sequence per region id (looked up by `sst_min_sequence`).
    sst_min_sequences: Arc<RwLock<HashMap<u64, u64>>>,
    #[builder(default)]
    /// Mutable session state: schema, timezone, user info, read preference,
    /// query timeout and open cursors.
    mutable_session_data: Arc<RwLock<MutableInner>>,
    #[builder(default)]
    /// Mutable per-query state: warning message and explain settings.
    mutable_query_context_data: Arc<RwLock<QueryContextMutableFields>>,
    /// Dialect used for SQL; `build` defaults it to `GreptimeDbDialect`.
    sql_dialect: Arc<dyn Dialect + Send + Sync>,
    #[builder(default)]
    /// Free-form string key/value extensions.
    extensions: HashMap<String, String>,
    #[builder(default)]
    /// Protocol configuration variables (PostgreSQL bytea output and datestyle).
    configuration_parameter: Arc<ConfigurationVariables>,
    #[builder(default)]
    /// Protocol channel the query arrived through.
    channel: Channel,
    #[builder(default)]
    /// Process id associated with this context.
    process_id: u32,
    #[builder(default)]
    /// Connection information (client address and channel).
    conn_info: ConnInfo,
}
74
/// Per-query fields of a [`QueryContext`] that are stored behind an `RwLock`
/// and can therefore be changed through a shared `&QueryContext`.
#[derive(Debug, Builder, Clone, Default)]
pub struct QueryContextMutableFields {
    /// Warning message attached to the query, if any.
    warning: Option<String>,
    /// Requested explain output format, if any.
    explain_format: Option<String>,
    /// Explain options (e.g. `verbose`); also carried in `api::v1::QueryContext::explain`.
    explain_options: Option<ExplainOptions>,
}
84
85impl Display for QueryContext {
86 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
87 write!(
88 f,
89 "QueryContext{{catalog: {}, schema: {}}}",
90 self.current_catalog(),
91 self.current_schema()
92 )
93 }
94}
95
96impl QueryContextBuilder {
97 pub fn current_schema(mut self, schema: String) -> Self {
98 if self.mutable_session_data.is_none() {
99 self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
100 }
101
102 self.mutable_session_data
104 .as_mut()
105 .unwrap()
106 .write()
107 .unwrap()
108 .schema = schema;
109 self
110 }
111
112 pub fn timezone(mut self, timezone: Timezone) -> Self {
113 if self.mutable_session_data.is_none() {
114 self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
115 }
116
117 self.mutable_session_data
118 .as_mut()
119 .unwrap()
120 .write()
121 .unwrap()
122 .timezone = timezone;
123 self
124 }
125
126 pub fn explain_options(mut self, explain_options: Option<ExplainOptions>) -> Self {
127 self.mutable_query_context_data
128 .get_or_insert_default()
129 .write()
130 .unwrap()
131 .explain_options = explain_options;
132 self
133 }
134
135 pub fn read_preference(mut self, read_preference: ReadPreference) -> Self {
136 self.mutable_session_data
137 .get_or_insert_default()
138 .write()
139 .unwrap()
140 .read_preference = read_preference;
141 self
142 }
143}
144
145impl From<&RegionRequestHeader> for QueryContext {
146 fn from(value: &RegionRequestHeader) -> Self {
147 if let Some(ctx) = &value.query_context {
148 ctx.clone().into()
149 } else {
150 QueryContextBuilder::default().build()
151 }
152 }
153}
154
155impl From<api::v1::QueryContext> for QueryContext {
156 fn from(ctx: api::v1::QueryContext) -> Self {
157 let sequences = ctx.snapshot_seqs.as_ref();
158 QueryContextBuilder::default()
159 .current_catalog(ctx.current_catalog)
160 .current_schema(ctx.current_schema)
161 .timezone(parse_timezone(Some(&ctx.timezone)))
162 .extensions(ctx.extensions)
163 .channel(ctx.channel.into())
164 .snapshot_seqs(Arc::new(RwLock::new(
165 sequences
166 .map(|x| x.snapshot_seqs.clone())
167 .unwrap_or_default(),
168 )))
169 .sst_min_sequences(Arc::new(RwLock::new(
170 sequences
171 .map(|x| x.sst_min_sequences.clone())
172 .unwrap_or_default(),
173 )))
174 .explain_options(ctx.explain)
175 .build()
176 }
177}
178
179impl From<QueryContext> for api::v1::QueryContext {
180 fn from(
181 QueryContext {
182 current_catalog,
183 mutable_session_data: mutable_inner,
184 extensions,
185 channel,
186 snapshot_seqs,
187 sst_min_sequences,
188 mutable_query_context_data,
189 ..
190 }: QueryContext,
191 ) -> Self {
192 let explain = mutable_query_context_data.read().unwrap().explain_options;
193 let mutable_inner = mutable_inner.read().unwrap();
194 api::v1::QueryContext {
195 current_catalog,
196 current_schema: mutable_inner.schema.clone(),
197 timezone: mutable_inner.timezone.to_string(),
198 extensions,
199 channel: channel as u32,
200 snapshot_seqs: Some(api::v1::SnapshotSequences {
201 snapshot_seqs: snapshot_seqs.read().unwrap().clone(),
202 sst_min_sequences: sst_min_sequences.read().unwrap().clone(),
203 }),
204 explain,
205 }
206 }
207}
208
209impl From<&QueryContext> for api::v1::QueryContext {
210 fn from(ctx: &QueryContext) -> Self {
211 ctx.clone().into()
212 }
213}
214
215impl QueryContext {
216 pub fn arc() -> QueryContextRef {
217 Arc::new(QueryContextBuilder::default().build())
218 }
219
220 pub fn with(catalog: &str, schema: &str) -> QueryContext {
221 QueryContextBuilder::default()
222 .current_catalog(catalog.to_string())
223 .current_schema(schema.to_string())
224 .build()
225 }
226
227 pub fn with_channel(catalog: &str, schema: &str, channel: Channel) -> QueryContext {
228 QueryContextBuilder::default()
229 .current_catalog(catalog.to_string())
230 .current_schema(schema.to_string())
231 .channel(channel)
232 .build()
233 }
234
235 pub fn with_db_name(db_name: Option<&str>) -> QueryContext {
236 let (catalog, schema) = db_name
237 .map(|db| {
238 let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
239 (catalog, schema)
240 })
241 .unwrap_or_else(|| {
242 (
243 DEFAULT_CATALOG_NAME.to_string(),
244 DEFAULT_SCHEMA_NAME.to_string(),
245 )
246 });
247 QueryContextBuilder::default()
248 .current_catalog(catalog)
249 .current_schema(schema.to_string())
250 .build()
251 }
252
253 pub fn current_schema(&self) -> String {
254 self.mutable_session_data.read().unwrap().schema.clone()
255 }
256
257 pub fn set_current_schema(&self, new_schema: &str) {
258 self.mutable_session_data.write().unwrap().schema = new_schema.to_string();
259 }
260
261 pub fn current_catalog(&self) -> &str {
262 &self.current_catalog
263 }
264
265 pub fn set_current_catalog(&mut self, new_catalog: &str) {
266 self.current_catalog = new_catalog.to_string();
267 }
268
269 pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) {
270 &*self.sql_dialect
271 }
272
273 pub fn get_db_string(&self) -> String {
274 let catalog = self.current_catalog();
275 let schema = self.current_schema();
276 build_db_string(catalog, &schema)
277 }
278
279 pub fn timezone(&self) -> Timezone {
280 self.mutable_session_data.read().unwrap().timezone.clone()
281 }
282
283 pub fn set_timezone(&self, timezone: Timezone) {
284 self.mutable_session_data.write().unwrap().timezone = timezone;
285 }
286
287 pub fn read_preference(&self) -> ReadPreference {
288 self.mutable_session_data.read().unwrap().read_preference
289 }
290
291 pub fn set_read_preference(&self, read_preference: ReadPreference) {
292 self.mutable_session_data.write().unwrap().read_preference = read_preference;
293 }
294
295 pub fn current_user(&self) -> UserInfoRef {
296 self.mutable_session_data.read().unwrap().user_info.clone()
297 }
298
299 pub fn set_current_user(&self, user: UserInfoRef) {
300 self.mutable_session_data.write().unwrap().user_info = user;
301 }
302
303 pub fn set_extension<S1: Into<String>, S2: Into<String>>(&mut self, key: S1, value: S2) {
304 self.extensions.insert(key.into(), value.into());
305 }
306
307 pub fn extension<S: AsRef<str>>(&self, key: S) -> Option<&str> {
308 self.extensions.get(key.as_ref()).map(|v| v.as_str())
309 }
310
311 pub fn extensions(&self) -> HashMap<String, String> {
312 self.extensions.clone()
313 }
314
315 pub fn quote_style(&self) -> char {
317 if self.sql_dialect().is_delimited_identifier_start('"') {
318 '"'
319 } else if self.sql_dialect().is_delimited_identifier_start('\'') {
320 '\''
321 } else {
322 '`'
323 }
324 }
325
326 pub fn configuration_parameter(&self) -> &ConfigurationVariables {
327 &self.configuration_parameter
328 }
329
330 pub fn channel(&self) -> Channel {
331 self.channel
332 }
333
334 pub fn set_channel(&mut self, channel: Channel) {
335 self.channel = channel;
336 }
337
338 pub fn warning(&self) -> Option<String> {
339 self.mutable_query_context_data
340 .read()
341 .unwrap()
342 .warning
343 .clone()
344 }
345
346 pub fn set_warning(&self, msg: String) {
347 self.mutable_query_context_data.write().unwrap().warning = Some(msg);
348 }
349
350 pub fn explain_format(&self) -> Option<String> {
351 self.mutable_query_context_data
352 .read()
353 .unwrap()
354 .explain_format
355 .clone()
356 }
357
358 pub fn set_explain_format(&self, format: String) {
359 self.mutable_query_context_data
360 .write()
361 .unwrap()
362 .explain_format = Some(format);
363 }
364
365 pub fn explain_verbose(&self) -> bool {
366 self.mutable_query_context_data
367 .read()
368 .unwrap()
369 .explain_options
370 .map(|opts| opts.verbose)
371 .unwrap_or(false)
372 }
373
374 pub fn set_explain_verbose(&self, verbose: bool) {
375 self.mutable_query_context_data
376 .write()
377 .unwrap()
378 .explain_options
379 .get_or_insert_default()
380 .verbose = verbose;
381 }
382
383 pub fn query_timeout(&self) -> Option<Duration> {
384 self.mutable_session_data.read().unwrap().query_timeout
385 }
386
387 pub fn query_timeout_as_millis(&self) -> u128 {
388 let timeout = self.mutable_session_data.read().unwrap().query_timeout;
389 if let Some(t) = timeout {
390 return t.as_millis();
391 }
392 0
393 }
394
395 pub fn set_query_timeout(&self, timeout: Duration) {
396 self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
397 }
398
399 pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
400 let mut guard = self.mutable_session_data.write().unwrap();
401 guard.cursors.insert(name, Arc::new(rb));
402
403 let cursor_count = guard.cursors.len();
404 if cursor_count > CURSOR_COUNT_WARNING_LIMIT {
405 warn!("Current connection has {} open cursors", cursor_count);
406 }
407 }
408
409 pub fn remove_cursor(&self, name: &str) {
410 let mut guard = self.mutable_session_data.write().unwrap();
411 guard.cursors.remove(name);
412 }
413
414 pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
415 let guard = self.mutable_session_data.read().unwrap();
416 let rb = guard.cursors.get(name);
417 rb.cloned()
418 }
419
420 pub fn snapshots(&self) -> HashMap<u64, u64> {
421 self.snapshot_seqs.read().unwrap().clone()
422 }
423
424 pub fn get_snapshot(&self, region_id: u64) -> Option<u64> {
425 self.snapshot_seqs.read().unwrap().get(®ion_id).cloned()
426 }
427
428 pub fn auto_string_to_numeric(&self) -> bool {
430 matches!(self.channel, Channel::Mysql)
431 }
432
433 pub fn sst_min_sequence(&self, region_id: u64) -> Option<u64> {
435 self.sst_min_sequences
436 .read()
437 .unwrap()
438 .get(®ion_id)
439 .copied()
440 }
441
442 pub fn process_id(&self) -> u32 {
443 self.process_id
444 }
445
446 pub fn conn_info(&self) -> &ConnInfo {
448 &self.conn_info
449 }
450}
451
452impl QueryContextBuilder {
453 pub fn build(self) -> QueryContext {
454 let channel = self.channel.unwrap_or_default();
455 QueryContext {
456 current_catalog: self
457 .current_catalog
458 .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()),
459 snapshot_seqs: self.snapshot_seqs.unwrap_or_default(),
460 sst_min_sequences: self.sst_min_sequences.unwrap_or_default(),
461 mutable_session_data: self.mutable_session_data.unwrap_or_default(),
462 mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(),
463 sql_dialect: self
464 .sql_dialect
465 .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
466 extensions: self.extensions.unwrap_or_default(),
467 configuration_parameter: self
468 .configuration_parameter
469 .unwrap_or_else(|| Arc::new(ConfigurationVariables::default())),
470 channel,
471 process_id: self.process_id.unwrap_or_default(),
472 conn_info: self.conn_info.unwrap_or_default(),
473 }
474 }
475
476 pub fn set_extension(mut self, key: String, value: String) -> Self {
477 self.extensions
478 .get_or_insert_with(HashMap::new)
479 .insert(key, value);
480 self
481 }
482}
483
/// Information about the client connection a query came from.
#[derive(Debug, Clone, Default)]
pub struct ConnInfo {
    /// Client socket address, if known (`Display` prints "unknown client addr" otherwise).
    pub client_addr: Option<SocketAddr>,
    /// Protocol channel of the connection.
    pub channel: Channel,
}
489
490impl Display for ConnInfo {
491 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
492 write!(
493 f,
494 "{}[{}]",
495 self.channel,
496 self.client_addr
497 .map(|addr| addr.to_string())
498 .as_deref()
499 .unwrap_or("unknown client addr")
500 )
501 }
502}
503
504impl ConnInfo {
505 pub fn new(client_addr: Option<SocketAddr>, channel: Channel) -> Self {
506 Self {
507 client_addr,
508 channel,
509 }
510 }
511}
512
/// Protocol channel through which a request arrived.
///
/// The explicit discriminants act as wire codes: `channel as u32` encodes a
/// channel into `api::v1::QueryContext::channel`, and `From<u32>` decodes it,
/// mapping unrecognized codes to `Unknown`.
#[derive(Debug, PartialEq, Default, Clone, Copy)]
#[repr(u8)]
pub enum Channel {
    /// Channel not known; the default.
    #[default]
    Unknown = 0,

    Mysql = 1,
    Postgres = 2,
    Http = 3,
    Prometheus = 4,
    Otlp = 5,
    Grpc = 6,
    Influx = 7,
    Opentsdb = 8,
    Loki = 9,
    Elasticsearch = 10,
    Jaeger = 11,
}
531
532impl From<u32> for Channel {
533 fn from(value: u32) -> Self {
534 match value {
535 1 => Self::Mysql,
536 2 => Self::Postgres,
537 3 => Self::Http,
538 4 => Self::Prometheus,
539 5 => Self::Otlp,
540 6 => Self::Grpc,
541 7 => Self::Influx,
542 8 => Self::Opentsdb,
543 9 => Self::Loki,
544 10 => Self::Elasticsearch,
545 11 => Self::Jaeger,
546 _ => Self::Unknown,
547 }
548 }
549}
550
551impl Channel {
552 pub fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
553 match self {
554 Channel::Mysql => Arc::new(MySqlDialect {}),
555 Channel::Postgres => Arc::new(PostgreSqlDialect {}),
556 _ => Arc::new(GenericDialect {}),
557 }
558 }
559}
560
561impl Display for Channel {
562 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
563 match self {
564 Channel::Mysql => write!(f, "mysql"),
565 Channel::Postgres => write!(f, "postgres"),
566 Channel::Http => write!(f, "http"),
567 Channel::Prometheus => write!(f, "prometheus"),
568 Channel::Otlp => write!(f, "otlp"),
569 Channel::Grpc => write!(f, "grpc"),
570 Channel::Influx => write!(f, "influx"),
571 Channel::Opentsdb => write!(f, "opentsdb"),
572 Channel::Loki => write!(f, "loki"),
573 Channel::Elasticsearch => write!(f, "elasticsearch"),
574 Channel::Jaeger => write!(f, "jaeger"),
575 Channel::Unknown => write!(f, "unknown"),
576 }
577 }
578}
579
/// Protocol configuration variables.
///
/// Each value is held in an `ArcSwap`, so it can be replaced through a shared
/// reference.
#[derive(Default, Debug)]
pub struct ConfigurationVariables {
    /// PostgreSQL bytea output setting.
    postgres_bytea_output: ArcSwap<PGByteaOutputValue>,
    /// PostgreSQL datestyle setting as a (style, date order) pair.
    pg_datestyle_format: ArcSwap<(PGDateTimeStyle, PGDateOrder)>,
}
585
586impl Clone for ConfigurationVariables {
587 fn clone(&self) -> Self {
588 Self {
589 postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()),
590 pg_datestyle_format: ArcSwap::new(self.pg_datestyle_format.load().clone()),
591 }
592 }
593}
594
595impl ConfigurationVariables {
596 pub fn new() -> Self {
597 Self::default()
598 }
599
600 pub fn set_postgres_bytea_output(&self, value: PGByteaOutputValue) {
601 let _ = self.postgres_bytea_output.swap(Arc::new(value));
602 }
603
604 pub fn postgres_bytea_output(&self) -> Arc<PGByteaOutputValue> {
605 self.postgres_bytea_output.load().clone()
606 }
607
608 pub fn pg_datetime_style(&self) -> Arc<(PGDateTimeStyle, PGDateOrder)> {
609 self.pg_datestyle_format.load().clone()
610 }
611
612 pub fn set_pg_datetime_style(&self, style: PGDateTimeStyle, order: PGDateOrder) {
613 self.pg_datestyle_format.swap(Arc::new((style, order)));
614 }
615}
616
#[cfg(test)]
mod test {
    use common_catalog::consts::DEFAULT_CATALOG_NAME;

    use super::*;
    use crate::context::Channel;
    use crate::Session;

    /// A MySQL session exposes its default user, channel, client address and process id.
    #[test]
    fn test_session() {
        let addr = "127.0.0.1:9000".parse().unwrap();
        let session = Session::new(Some(addr), Channel::Mysql, Default::default(), 100);
        assert_eq!(session.user_info().username(), "greptime");

        let info = session.conn_info();
        assert_eq!(info.channel, Channel::Mysql);

        let client_addr = info.client_addr.as_ref().unwrap();
        assert_eq!(client_addr.ip().to_string(), "127.0.0.1");
        assert_eq!(client_addr.port(), 9000);
        assert_eq!("mysql[127.0.0.1:9000]", info.to_string());

        assert_eq!(100, session.process_id());
    }

    /// A non-default catalog is prefixed to the schema; the default catalog is omitted.
    #[test]
    fn test_context_db_string() {
        let cases = [
            ("a0b1c2d3", "test", "a0b1c2d3-test"),
            (DEFAULT_CATALOG_NAME, "test", "test"),
        ];
        for (catalog, schema, expected) in cases {
            let context = QueryContext::with(catalog, schema);
            assert_eq!(expected, context.get_db_string());
        }
    }
}