1use std::collections::HashMap;
16use std::fmt::{Display, Formatter};
17use std::net::SocketAddr;
18use std::sync::{Arc, RwLock};
19use std::time::Duration;
20
21use api::v1::region::RegionRequestHeader;
22use api::v1::ExplainOptions;
23use arc_swap::ArcSwap;
24use auth::UserInfoRef;
25use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
26use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
27use common_recordbatch::cursor::RecordBatchStreamCursor;
28use common_telemetry::warn;
29use common_time::timezone::parse_timezone;
30use common_time::Timezone;
31use derive_builder::Builder;
32use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
33
34use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
35use crate::{MutableInner, ReadPreference};
36
/// Shared, reference-counted handle to a [`QueryContext`].
pub type QueryContextRef = Arc<QueryContext>;
/// Shared, reference-counted handle to a [`ConnInfo`].
pub type ConnInfoRef = Arc<ConnInfo>;

/// Number of open cursors on one session above which `QueryContext::insert_cursor`
/// logs a warning (the cursor is still stored; this is only a diagnostic threshold).
const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
41
/// Per-query context: catalog, session state, SQL dialect, extensions and the
/// channel the query arrived on.
///
/// `Clone` is derived, so a clone copies `current_catalog`, `extensions` and
/// `channel`, but shares every `Arc`-wrapped field (session data, per-query
/// mutable data, snapshot/SST sequence maps, configuration variables) with the
/// original. Writes through one clone are therefore visible through the others.
///
/// Construct it with [`QueryContextBuilder`]; its hand-written `build` fills in
/// defaults for every field that was not set.
#[derive(Debug, Builder, Clone)]
#[builder(pattern = "owned")]
#[builder(build_fn(skip))]
pub struct QueryContext {
    /// Name of the catalog the query runs in.
    current_catalog: String,
    /// Snapshot sequence numbers keyed by region id (read by `get_snapshot`).
    snapshot_seqs: Arc<RwLock<HashMap<u64, u64>>>,
    /// Minimal SST sequence numbers keyed by region id (read by `sst_min_sequence`).
    sst_min_sequences: Arc<RwLock<HashMap<u64, u64>>>,
    /// Session-level mutable state: schema, timezone, read preference, user,
    /// query timeout and cursors.
    #[builder(default)]
    mutable_session_data: Arc<RwLock<MutableInner>>,
    /// Per-query mutable state: warning, explain format and explain options.
    #[builder(default)]
    mutable_query_context_data: Arc<RwLock<QueryContextMutableFields>>,
    /// SQL dialect of this context; consulted by `quote_style`.
    sql_dialect: Arc<dyn Dialect + Send + Sync>,
    /// Free-form string key/value extensions.
    #[builder(default)]
    extensions: HashMap<String, String>,
    /// Session configuration variables (PostgreSQL bytea output and datestyle).
    #[builder(default)]
    configuration_parameter: Arc<ConfigurationVariables>,
    /// Protocol channel the query arrived on.
    #[builder(default)]
    channel: Channel,
}
68
/// Per-query values that can be changed after a [`QueryContext`] is built.
///
/// `QueryContext` keeps this behind an `Arc<RwLock<..>>`, so the setters on
/// `QueryContext` only need `&self`.
#[derive(Debug, Builder, Clone, Default)]
pub struct QueryContextMutableFields {
    /// Warning recorded by `set_warning`, if any.
    warning: Option<String>,
    /// Explain format recorded by `set_explain_format`, if any.
    explain_format: Option<String>,
    /// Explain options (e.g. `verbose`), if any; serialized into the protobuf context.
    explain_options: Option<ExplainOptions>,
}
78
79impl Display for QueryContext {
80 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
81 write!(
82 f,
83 "QueryContext{{catalog: {}, schema: {}}}",
84 self.current_catalog(),
85 self.current_schema()
86 )
87 }
88}
89
90impl QueryContextBuilder {
91 pub fn current_schema(mut self, schema: String) -> Self {
92 if self.mutable_session_data.is_none() {
93 self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
94 }
95
96 self.mutable_session_data
98 .as_mut()
99 .unwrap()
100 .write()
101 .unwrap()
102 .schema = schema;
103 self
104 }
105
106 pub fn timezone(mut self, timezone: Timezone) -> Self {
107 if self.mutable_session_data.is_none() {
108 self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
109 }
110
111 self.mutable_session_data
112 .as_mut()
113 .unwrap()
114 .write()
115 .unwrap()
116 .timezone = timezone;
117 self
118 }
119
120 pub fn explain_options(mut self, explain_options: Option<ExplainOptions>) -> Self {
121 self.mutable_query_context_data
122 .get_or_insert_default()
123 .write()
124 .unwrap()
125 .explain_options = explain_options;
126 self
127 }
128
129 pub fn read_preference(mut self, read_preference: ReadPreference) -> Self {
130 self.mutable_session_data
131 .get_or_insert_default()
132 .write()
133 .unwrap()
134 .read_preference = read_preference;
135 self
136 }
137}
138
139impl From<&RegionRequestHeader> for QueryContext {
140 fn from(value: &RegionRequestHeader) -> Self {
141 if let Some(ctx) = &value.query_context {
142 ctx.clone().into()
143 } else {
144 QueryContextBuilder::default().build()
145 }
146 }
147}
148
149impl From<api::v1::QueryContext> for QueryContext {
150 fn from(ctx: api::v1::QueryContext) -> Self {
151 let sequences = ctx.snapshot_seqs.as_ref();
152 QueryContextBuilder::default()
153 .current_catalog(ctx.current_catalog)
154 .current_schema(ctx.current_schema)
155 .timezone(parse_timezone(Some(&ctx.timezone)))
156 .extensions(ctx.extensions)
157 .channel(ctx.channel.into())
158 .snapshot_seqs(Arc::new(RwLock::new(
159 sequences
160 .map(|x| x.snapshot_seqs.clone())
161 .unwrap_or_default(),
162 )))
163 .sst_min_sequences(Arc::new(RwLock::new(
164 sequences
165 .map(|x| x.sst_min_sequences.clone())
166 .unwrap_or_default(),
167 )))
168 .explain_options(ctx.explain)
169 .build()
170 }
171}
172
173impl From<QueryContext> for api::v1::QueryContext {
174 fn from(
175 QueryContext {
176 current_catalog,
177 mutable_session_data: mutable_inner,
178 extensions,
179 channel,
180 snapshot_seqs,
181 sst_min_sequences,
182 mutable_query_context_data,
183 ..
184 }: QueryContext,
185 ) -> Self {
186 let explain = mutable_query_context_data.read().unwrap().explain_options;
187 let mutable_inner = mutable_inner.read().unwrap();
188 api::v1::QueryContext {
189 current_catalog,
190 current_schema: mutable_inner.schema.clone(),
191 timezone: mutable_inner.timezone.to_string(),
192 extensions,
193 channel: channel as u32,
194 snapshot_seqs: Some(api::v1::SnapshotSequences {
195 snapshot_seqs: snapshot_seqs.read().unwrap().clone(),
196 sst_min_sequences: sst_min_sequences.read().unwrap().clone(),
197 }),
198 explain,
199 }
200 }
201}
202
203impl From<&QueryContext> for api::v1::QueryContext {
204 fn from(ctx: &QueryContext) -> Self {
205 ctx.clone().into()
206 }
207}
208
209impl QueryContext {
210 pub fn arc() -> QueryContextRef {
211 Arc::new(QueryContextBuilder::default().build())
212 }
213
214 pub fn with(catalog: &str, schema: &str) -> QueryContext {
215 QueryContextBuilder::default()
216 .current_catalog(catalog.to_string())
217 .current_schema(schema.to_string())
218 .build()
219 }
220
221 pub fn with_channel(catalog: &str, schema: &str, channel: Channel) -> QueryContext {
222 QueryContextBuilder::default()
223 .current_catalog(catalog.to_string())
224 .current_schema(schema.to_string())
225 .channel(channel)
226 .build()
227 }
228
229 pub fn with_db_name(db_name: Option<&str>) -> QueryContext {
230 let (catalog, schema) = db_name
231 .map(|db| {
232 let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
233 (catalog, schema)
234 })
235 .unwrap_or_else(|| {
236 (
237 DEFAULT_CATALOG_NAME.to_string(),
238 DEFAULT_SCHEMA_NAME.to_string(),
239 )
240 });
241 QueryContextBuilder::default()
242 .current_catalog(catalog)
243 .current_schema(schema.to_string())
244 .build()
245 }
246
247 pub fn current_schema(&self) -> String {
248 self.mutable_session_data.read().unwrap().schema.clone()
249 }
250
251 pub fn set_current_schema(&self, new_schema: &str) {
252 self.mutable_session_data.write().unwrap().schema = new_schema.to_string();
253 }
254
255 pub fn current_catalog(&self) -> &str {
256 &self.current_catalog
257 }
258
259 pub fn set_current_catalog(&mut self, new_catalog: &str) {
260 self.current_catalog = new_catalog.to_string();
261 }
262
263 pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) {
264 &*self.sql_dialect
265 }
266
267 pub fn get_db_string(&self) -> String {
268 let catalog = self.current_catalog();
269 let schema = self.current_schema();
270 build_db_string(catalog, &schema)
271 }
272
273 pub fn timezone(&self) -> Timezone {
274 self.mutable_session_data.read().unwrap().timezone.clone()
275 }
276
277 pub fn set_timezone(&self, timezone: Timezone) {
278 self.mutable_session_data.write().unwrap().timezone = timezone;
279 }
280
281 pub fn read_preference(&self) -> ReadPreference {
282 self.mutable_session_data.read().unwrap().read_preference
283 }
284
285 pub fn set_read_preference(&self, read_preference: ReadPreference) {
286 self.mutable_session_data.write().unwrap().read_preference = read_preference;
287 }
288
289 pub fn current_user(&self) -> UserInfoRef {
290 self.mutable_session_data.read().unwrap().user_info.clone()
291 }
292
293 pub fn set_current_user(&self, user: UserInfoRef) {
294 self.mutable_session_data.write().unwrap().user_info = user;
295 }
296
297 pub fn set_extension<S1: Into<String>, S2: Into<String>>(&mut self, key: S1, value: S2) {
298 self.extensions.insert(key.into(), value.into());
299 }
300
301 pub fn extension<S: AsRef<str>>(&self, key: S) -> Option<&str> {
302 self.extensions.get(key.as_ref()).map(|v| v.as_str())
303 }
304
305 pub fn extensions(&self) -> HashMap<String, String> {
306 self.extensions.clone()
307 }
308
309 pub fn quote_style(&self) -> char {
311 if self.sql_dialect().is_delimited_identifier_start('"') {
312 '"'
313 } else if self.sql_dialect().is_delimited_identifier_start('\'') {
314 '\''
315 } else {
316 '`'
317 }
318 }
319
320 pub fn configuration_parameter(&self) -> &ConfigurationVariables {
321 &self.configuration_parameter
322 }
323
324 pub fn channel(&self) -> Channel {
325 self.channel
326 }
327
328 pub fn set_channel(&mut self, channel: Channel) {
329 self.channel = channel;
330 }
331
332 pub fn warning(&self) -> Option<String> {
333 self.mutable_query_context_data
334 .read()
335 .unwrap()
336 .warning
337 .clone()
338 }
339
340 pub fn set_warning(&self, msg: String) {
341 self.mutable_query_context_data.write().unwrap().warning = Some(msg);
342 }
343
344 pub fn explain_format(&self) -> Option<String> {
345 self.mutable_query_context_data
346 .read()
347 .unwrap()
348 .explain_format
349 .clone()
350 }
351
352 pub fn set_explain_format(&self, format: String) {
353 self.mutable_query_context_data
354 .write()
355 .unwrap()
356 .explain_format = Some(format);
357 }
358
359 pub fn explain_verbose(&self) -> bool {
360 self.mutable_query_context_data
361 .read()
362 .unwrap()
363 .explain_options
364 .map(|opts| opts.verbose)
365 .unwrap_or(false)
366 }
367
368 pub fn set_explain_verbose(&self, verbose: bool) {
369 self.mutable_query_context_data
370 .write()
371 .unwrap()
372 .explain_options
373 .get_or_insert_default()
374 .verbose = verbose;
375 }
376
377 pub fn query_timeout(&self) -> Option<Duration> {
378 self.mutable_session_data.read().unwrap().query_timeout
379 }
380
381 pub fn query_timeout_as_millis(&self) -> u128 {
382 let timeout = self.mutable_session_data.read().unwrap().query_timeout;
383 if let Some(t) = timeout {
384 return t.as_millis();
385 }
386 0
387 }
388
389 pub fn set_query_timeout(&self, timeout: Duration) {
390 self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
391 }
392
393 pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
394 let mut guard = self.mutable_session_data.write().unwrap();
395 guard.cursors.insert(name, Arc::new(rb));
396
397 let cursor_count = guard.cursors.len();
398 if cursor_count > CURSOR_COUNT_WARNING_LIMIT {
399 warn!("Current connection has {} open cursors", cursor_count);
400 }
401 }
402
403 pub fn remove_cursor(&self, name: &str) {
404 let mut guard = self.mutable_session_data.write().unwrap();
405 guard.cursors.remove(name);
406 }
407
408 pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
409 let guard = self.mutable_session_data.read().unwrap();
410 let rb = guard.cursors.get(name);
411 rb.cloned()
412 }
413
414 pub fn snapshots(&self) -> HashMap<u64, u64> {
415 self.snapshot_seqs.read().unwrap().clone()
416 }
417
418 pub fn get_snapshot(&self, region_id: u64) -> Option<u64> {
419 self.snapshot_seqs.read().unwrap().get(®ion_id).cloned()
420 }
421
422 pub fn auto_string_to_numeric(&self) -> bool {
424 matches!(self.channel, Channel::Mysql)
425 }
426
427 pub fn sst_min_sequence(&self, region_id: u64) -> Option<u64> {
429 self.sst_min_sequences
430 .read()
431 .unwrap()
432 .get(®ion_id)
433 .copied()
434 }
435}
436
437impl QueryContextBuilder {
438 pub fn build(self) -> QueryContext {
439 let channel = self.channel.unwrap_or_default();
440 QueryContext {
441 current_catalog: self
442 .current_catalog
443 .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()),
444 snapshot_seqs: self.snapshot_seqs.unwrap_or_default(),
445 sst_min_sequences: self.sst_min_sequences.unwrap_or_default(),
446 mutable_session_data: self.mutable_session_data.unwrap_or_default(),
447 mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(),
448 sql_dialect: self
449 .sql_dialect
450 .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
451 extensions: self.extensions.unwrap_or_default(),
452 configuration_parameter: self
453 .configuration_parameter
454 .unwrap_or_else(|| Arc::new(ConfigurationVariables::default())),
455 channel,
456 }
457 }
458
459 pub fn set_extension(mut self, key: String, value: String) -> Self {
460 self.extensions
461 .get_or_insert_with(HashMap::new)
462 .insert(key, value);
463 self
464 }
465}
466
/// Information about a client connection: its address and protocol channel.
#[derive(Debug)]
pub struct ConnInfo {
    /// Remote address of the client, when known.
    pub client_addr: Option<SocketAddr>,
    /// Protocol channel the connection uses.
    pub channel: Channel,
}
472
473impl Display for ConnInfo {
474 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
475 write!(
476 f,
477 "{}[{}]",
478 self.channel,
479 self.client_addr
480 .map(|addr| addr.to_string())
481 .as_deref()
482 .unwrap_or("unknown client addr")
483 )
484 }
485}
486
487impl ConnInfo {
488 pub fn new(client_addr: Option<SocketAddr>, channel: Channel) -> Self {
489 Self {
490 client_addr,
491 channel,
492 }
493 }
494}
495
/// Client-facing protocol a connection or request came in through.
///
/// The explicit discriminants double as the numeric code used in the protobuf
/// query context. It is encoded with `channel as u32` and decoded with
/// `Channel::from(u32)`.
#[derive(Debug, PartialEq, Default, Clone, Copy)]
#[repr(u8)]
pub enum Channel {
    /// Fallback for unset or unrecognized channel codes.
    #[default]
    Unknown = 0,

    Mysql = 1,
    Postgres = 2,
    Http = 3,
    Prometheus = 4,
    Otlp = 5,
    Grpc = 6,
    Influx = 7,
    Opentsdb = 8,
    Loki = 9,
    Elasticsearch = 10,
    Jaeger = 11,
}

impl From<u32> for Channel {
    /// Decodes a numeric channel code.
    ///
    /// Every code without a matching variant (including `0`) maps to
    /// `Channel::Unknown`.
    fn from(value: u32) -> Self {
        match value {
            1 => Channel::Mysql,
            2 => Channel::Postgres,
            3 => Channel::Http,
            4 => Channel::Prometheus,
            5 => Channel::Otlp,
            6 => Channel::Grpc,
            7 => Channel::Influx,
            8 => Channel::Opentsdb,
            9 => Channel::Loki,
            10 => Channel::Elasticsearch,
            11 => Channel::Jaeger,
            _ => Channel::Unknown,
        }
    }
}
533
534impl Channel {
535 pub fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
536 match self {
537 Channel::Mysql => Arc::new(MySqlDialect {}),
538 Channel::Postgres => Arc::new(PostgreSqlDialect {}),
539 _ => Arc::new(GenericDialect {}),
540 }
541 }
542}
543
544impl Display for Channel {
545 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
546 match self {
547 Channel::Mysql => write!(f, "mysql"),
548 Channel::Postgres => write!(f, "postgres"),
549 Channel::Http => write!(f, "http"),
550 Channel::Prometheus => write!(f, "prometheus"),
551 Channel::Otlp => write!(f, "otlp"),
552 Channel::Grpc => write!(f, "grpc"),
553 Channel::Influx => write!(f, "influx"),
554 Channel::Opentsdb => write!(f, "opentsdb"),
555 Channel::Loki => write!(f, "loki"),
556 Channel::Elasticsearch => write!(f, "elasticsearch"),
557 Channel::Jaeger => write!(f, "jaeger"),
558 Channel::Unknown => write!(f, "unknown"),
559 }
560 }
561}
562
/// Session configuration variables that can be changed through a shared reference.
///
/// Each value sits in an `ArcSwap`, so setters take `&self` and readers receive
/// an `Arc` snapshot of the current value.
#[derive(Default, Debug)]
pub struct ConfigurationVariables {
    /// Current PostgreSQL bytea output setting (presumably `bytea_output`; TODO confirm).
    postgres_bytea_output: ArcSwap<PGByteaOutputValue>,
    /// Current PostgreSQL datestyle setting: date/time style and date field order.
    pg_datestyle_format: ArcSwap<(PGDateTimeStyle, PGDateOrder)>,
}
568
569impl Clone for ConfigurationVariables {
570 fn clone(&self) -> Self {
571 Self {
572 postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()),
573 pg_datestyle_format: ArcSwap::new(self.pg_datestyle_format.load().clone()),
574 }
575 }
576}
577
578impl ConfigurationVariables {
579 pub fn new() -> Self {
580 Self::default()
581 }
582
583 pub fn set_postgres_bytea_output(&self, value: PGByteaOutputValue) {
584 let _ = self.postgres_bytea_output.swap(Arc::new(value));
585 }
586
587 pub fn postgres_bytea_output(&self) -> Arc<PGByteaOutputValue> {
588 self.postgres_bytea_output.load().clone()
589 }
590
591 pub fn pg_datetime_style(&self) -> Arc<(PGDateTimeStyle, PGDateOrder)> {
592 self.pg_datestyle_format.load().clone()
593 }
594
595 pub fn set_pg_datetime_style(&self, style: PGDateTimeStyle, order: PGDateOrder) {
596 self.pg_datestyle_format.swap(Arc::new((style, order)));
597 }
598}
599
#[cfg(test)]
mod test {
    use common_catalog::consts::DEFAULT_CATALOG_NAME;

    use super::*;
    use crate::context::Channel;
    use crate::Session;

    /// A new session exposes its default user, channel and client address, and
    /// its `ConnInfo` renders as `<channel>[<addr>]`.
    #[test]
    fn test_session() {
        let session = Session::new(
            Some("127.0.0.1:9000".parse().unwrap()),
            Channel::Mysql,
            Default::default(),
        );
        assert_eq!(session.user_info().username(), "greptime");

        assert_eq!(session.conn_info().channel, Channel::Mysql);
        let client_addr = session.conn_info().client_addr.as_ref().unwrap();
        assert_eq!(client_addr.ip().to_string(), "127.0.0.1");
        assert_eq!(client_addr.port(), 9000);

        assert_eq!("mysql[127.0.0.1:9000]", session.conn_info().to_string());
    }

    /// The db string is `<catalog>-<schema>`, except that the default catalog is
    /// omitted.
    #[test]
    fn test_context_db_string() {
        let context = QueryContext::with("a0b1c2d3", "test");
        assert_eq!("a0b1c2d3-test", context.get_db_string());

        let context = QueryContext::with(DEFAULT_CATALOG_NAME, "test");
        assert_eq!("test", context.get_db_string());
    }
}