catalog/
process_manager.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::collections::hash_map::Entry;
17use std::fmt::{Debug, Display, Formatter};
18use std::sync::atomic::{AtomicU32, Ordering};
19use std::sync::{Arc, RwLock};
20use std::time::{Duration, Instant, UNIX_EPOCH};
21
22use api::v1::frontend::{KillProcessRequest, ListProcessRequest, ProcessInfo};
23use common_base::cancellation::CancellationHandle;
24use common_event_recorder::EventRecorderRef;
25use common_frontend::selector::{FrontendSelector, MetaClientSelector};
26use common_frontend::slow_query_event::SlowQueryEvent;
27use common_telemetry::logging::SlowQueriesRecordType;
28use common_telemetry::{debug, info, slow, warn};
29use common_time::util::current_time_millis;
30use meta_client::MetaClientRef;
31use promql_parser::parser::EvalStmt;
32use rand::random;
33use snafu::{OptionExt, ResultExt, ensure};
34use sql::statements::statement::Statement;
35
36use crate::error;
37use crate::metrics::{PROCESS_KILL_COUNT, PROCESS_LIST_COUNT};
38
39pub type ProcessId = u32;
40pub type ProcessManagerRef = Arc<ProcessManager>;
41
/// Query process manager.
///
/// Keeps track of the queries registered on this frontend, grouped by catalog, and
/// optionally holds a selector used to reach other frontend nodes.
pub struct ProcessManager {
    /// Local frontend server address.
    server_addr: String,
    /// Next process id for local queries.
    next_id: AtomicU32,
    /// Running processes per catalog: catalog name -> (process id -> process).
    catalogs: RwLock<HashMap<String, HashMap<ProcessId, CancellableProcess>>>,
    /// Frontend selector to locate frontend nodes; `None` when no meta client was provided.
    frontend_selector: Option<MetaClientSelector>,
}
53
/// Represents a parsed query statement, functionally equivalent to [query::parser::QueryStatement].
/// This enum is defined here to avoid cyclic dependencies with the query parser module.
#[derive(Debug, Clone)]
pub enum QueryStatement {
    /// A parsed SQL statement.
    Sql(Statement),
    /// A PromQL evaluation statement; the optional string is the alias of the PromQL query.
    Promql(EvalStmt, Option<String>),
}
62
63impl Display for QueryStatement {
64    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
65        match self {
66            QueryStatement::Sql(stmt) => write!(f, "{}", stmt),
67            QueryStatement::Promql(eval_stmt, alias) => {
68                if let Some(alias) = alias {
69                    write!(f, "{} AS {}", eval_stmt, alias)
70                } else {
71                    write!(f, "{}", eval_stmt)
72                }
73            }
74        }
75    }
76}
77
78impl ProcessManager {
79    /// Create a [ProcessManager] instance with server address and kv client.
80    pub fn new(server_addr: String, meta_client: Option<MetaClientRef>) -> Self {
81        let frontend_selector = meta_client.map(MetaClientSelector::new);
82        Self {
83            server_addr,
84            next_id: Default::default(),
85            catalogs: Default::default(),
86            frontend_selector,
87        }
88    }
89}
90
91impl ProcessManager {
92    /// Registers a submitted query. Use the provided id if present.
93    #[must_use]
94    pub fn register_query(
95        self: &Arc<Self>,
96        catalog: String,
97        schemas: Vec<String>,
98        query: String,
99        client: String,
100        query_id: Option<ProcessId>,
101        _slow_query_timer: Option<SlowQueryTimer>,
102    ) -> Ticket {
103        let id = query_id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed));
104        let process = ProcessInfo {
105            id,
106            catalog: catalog.clone(),
107            schemas,
108            query,
109            start_timestamp: current_time_millis(),
110            client,
111            frontend: self.server_addr.clone(),
112        };
113        let cancellation_handle = Arc::new(CancellationHandle::default());
114        let cancellable_process = CancellableProcess::new(cancellation_handle.clone(), process);
115
116        self.catalogs
117            .write()
118            .unwrap()
119            .entry(catalog.clone())
120            .or_default()
121            .insert(id, cancellable_process);
122
123        Ticket {
124            catalog,
125            manager: self.clone(),
126            id,
127            cancellation_handle,
128            _slow_query_timer,
129        }
130    }
131
132    /// Generates the next process id.
133    pub fn next_id(&self) -> u32 {
134        self.next_id.fetch_add(1, Ordering::Relaxed)
135    }
136
137    /// De-register a query from process list.
138    pub fn deregister_query(&self, catalog: String, id: ProcessId) {
139        if let Entry::Occupied(mut o) = self.catalogs.write().unwrap().entry(catalog) {
140            let process = o.get_mut().remove(&id);
141            debug!("Deregister process: {:?}", process);
142            if o.get().is_empty() {
143                o.remove();
144            }
145        }
146    }
147
148    /// List local running processes in given catalog.
149    pub fn local_processes(&self, catalog: Option<&str>) -> error::Result<Vec<ProcessInfo>> {
150        let catalogs = self.catalogs.read().unwrap();
151        let result = if let Some(catalog) = catalog {
152            if let Some(catalogs) = catalogs.get(catalog) {
153                catalogs.values().map(|p| p.process.clone()).collect()
154            } else {
155                vec![]
156            }
157        } else {
158            catalogs
159                .values()
160                .flat_map(|v| v.values().map(|p| p.process.clone()))
161                .collect()
162        };
163        Ok(result)
164    }
165
166    pub async fn list_all_processes(
167        &self,
168        catalog: Option<&str>,
169    ) -> error::Result<Vec<ProcessInfo>> {
170        let mut processes = vec![];
171        if let Some(remote_frontend_selector) = self.frontend_selector.as_ref() {
172            let frontends = remote_frontend_selector
173                .select(|node| node.peer.addr != self.server_addr)
174                .await
175                .context(error::InvokeFrontendSnafu)?;
176            for mut f in frontends {
177                let result = f
178                    .list_process(ListProcessRequest {
179                        catalog: catalog.unwrap_or_default().to_string(),
180                    })
181                    .await
182                    .context(error::InvokeFrontendSnafu);
183                match result {
184                    Ok(resp) => {
185                        processes.extend(resp.processes);
186                    }
187                    Err(e) => {
188                        warn!(e; "Skipping failing node: {:?}", f)
189                    }
190                }
191            }
192        }
193        processes.extend(self.local_processes(catalog)?);
194        Ok(processes)
195    }
196
197    /// Kills query with provided catalog and id.
198    pub async fn kill_process(
199        &self,
200        server_addr: String,
201        catalog: String,
202        id: ProcessId,
203    ) -> error::Result<bool> {
204        if server_addr == self.server_addr {
205            self.kill_local_process(catalog, id).await
206        } else {
207            let mut nodes = self
208                .frontend_selector
209                .as_ref()
210                .context(error::MetaClientMissingSnafu)?
211                .select(|node| node.peer.addr == server_addr)
212                .await
213                .context(error::InvokeFrontendSnafu)?;
214            ensure!(
215                !nodes.is_empty(),
216                error::FrontendNotFoundSnafu { addr: server_addr }
217            );
218
219            let request = KillProcessRequest {
220                server_addr,
221                catalog,
222                process_id: id,
223            };
224            nodes[0]
225                .kill_process(request)
226                .await
227                .context(error::InvokeFrontendSnafu)?;
228            Ok(true)
229        }
230    }
231
232    /// Kills local query with provided catalog and id.
233    pub async fn kill_local_process(&self, catalog: String, id: ProcessId) -> error::Result<bool> {
234        if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) {
235            if let Some(process) = catalogs.remove(&id) {
236                process.handle.cancel();
237                info!(
238                    "Killed process, catalog: {}, id: {:?}",
239                    process.process.catalog, process.process.id
240                );
241                PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc();
242                Ok(true)
243            } else {
244                debug!("Failed to kill process, id not found: {}", id);
245                Ok(false)
246            }
247        } else {
248            debug!("Failed to kill process, catalog not found: {}", catalog);
249            Ok(false)
250        }
251    }
252}
253
/// Handle of a registered query.
///
/// Dropping the ticket deregisters the query from its [ProcessManager].
pub struct Ticket {
    /// Catalog the query was registered under.
    pub(crate) catalog: String,
    /// Manager that owns the registration.
    pub(crate) manager: ProcessManagerRef,
    /// Process id of the registered query.
    pub(crate) id: ProcessId,
    /// Cancellation handle shared with the manager's entry for this query.
    pub cancellation_handle: Arc<CancellationHandle>,

    // Keep the handle of the slow query timer to ensure it will trigger the event recording when dropped.
    _slow_query_timer: Option<SlowQueryTimer>,
}
263
264impl Drop for Ticket {
265    fn drop(&mut self) {
266        self.manager
267            .deregister_query(std::mem::take(&mut self.catalog), self.id);
268    }
269}
270
/// A registered process together with the handle used to cancel it.
struct CancellableProcess {
    /// Handle cancelled when the process is killed.
    handle: Arc<CancellationHandle>,
    /// Description of the process as reported in process listings.
    process: ProcessInfo,
}
275
276impl Drop for CancellableProcess {
277    fn drop(&mut self) {
278        PROCESS_LIST_COUNT
279            .with_label_values(&[&self.process.catalog])
280            .dec();
281    }
282}
283
284impl CancellableProcess {
285    fn new(handle: Arc<CancellationHandle>, process: ProcessInfo) -> Self {
286        PROCESS_LIST_COUNT
287            .with_label_values(&[&process.catalog])
288            .inc();
289        Self { handle, process }
290    }
291}
292
293impl Debug for CancellableProcess {
294    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
295        f.debug_struct("CancellableProcess")
296            .field("cancelled", &self.handle.is_cancelled())
297            .field("process", &self.process)
298            .finish()
299    }
300}
301
/// SlowQueryTimer is used to log slow query when it's dropped.
/// In drop(), it will check if the query is slow and send the slow query event to the handler.
pub struct SlowQueryTimer {
    /// Instant at which the timer was created.
    start: Instant,
    /// The statement being timed.
    stmt: QueryStatement,
    /// A query is considered slow when its elapsed time exceeds this threshold.
    threshold: Duration,
    /// Sampling ratio applied to slow queries; values `>= 1.0` report every slow query.
    sample_ratio: f64,
    /// Where slow query events are sent: the event recorder (system table) or the slow query log.
    record_type: SlowQueriesRecordType,
    /// Recorder used when `record_type` is `SystemTable`.
    recorder: EventRecorderRef,
}
312
313impl SlowQueryTimer {
314    pub fn new(
315        stmt: QueryStatement,
316        threshold: Duration,
317        sample_ratio: f64,
318        record_type: SlowQueriesRecordType,
319        recorder: EventRecorderRef,
320    ) -> Self {
321        Self {
322            start: Instant::now(),
323            stmt,
324            threshold,
325            sample_ratio,
326            record_type,
327            recorder,
328        }
329    }
330}
331
332impl SlowQueryTimer {
333    fn send_slow_query_event(&self, elapsed: Duration) {
334        let mut slow_query_event = SlowQueryEvent {
335            cost: elapsed.as_millis() as u64,
336            threshold: self.threshold.as_millis() as u64,
337            query: "".to_string(),
338
339            // The following fields are only used for PromQL queries.
340            is_promql: false,
341            promql_range: None,
342            promql_step: None,
343            promql_start: None,
344            promql_end: None,
345        };
346
347        match &self.stmt {
348            QueryStatement::Promql(stmt, _alias) => {
349                slow_query_event.is_promql = true;
350                slow_query_event.query = self.stmt.to_string();
351                slow_query_event.promql_step = Some(stmt.interval.as_millis() as u64);
352
353                let start = stmt
354                    .start
355                    .duration_since(UNIX_EPOCH)
356                    .unwrap_or_default()
357                    .as_millis() as i64;
358
359                let end = stmt
360                    .end
361                    .duration_since(UNIX_EPOCH)
362                    .unwrap_or_default()
363                    .as_millis() as i64;
364
365                slow_query_event.promql_range = Some((end - start) as u64);
366                slow_query_event.promql_start = Some(start);
367                slow_query_event.promql_end = Some(end);
368            }
369            QueryStatement::Sql(stmt) => {
370                slow_query_event.query = stmt.to_string();
371            }
372        }
373
374        match self.record_type {
375            // Send the slow query event to the event recorder to persist it as the system table.
376            SlowQueriesRecordType::SystemTable => {
377                self.recorder.record(Box::new(slow_query_event));
378            }
379            // Record the slow query in a specific logs file.
380            SlowQueriesRecordType::Log => {
381                slow!(
382                    cost = slow_query_event.cost,
383                    threshold = slow_query_event.threshold,
384                    query = slow_query_event.query,
385                    is_promql = slow_query_event.is_promql,
386                    promql_range = slow_query_event.promql_range,
387                    promql_step = slow_query_event.promql_step,
388                    promql_start = slow_query_event.promql_start,
389                    promql_end = slow_query_event.promql_end,
390                );
391            }
392        }
393    }
394}
395
396impl Drop for SlowQueryTimer {
397    fn drop(&mut self) {
398        // Calculate the elaspsed duration since the timer is created.
399        let elapsed = self.start.elapsed();
400        if elapsed > self.threshold {
401            // Only capture a portion of slow queries based on sample_ratio.
402            // Generate a random number in [0, 1) and compare it with sample_ratio.
403            if self.sample_ratio >= 1.0 || random::<f64>() <= self.sample_ratio {
404                self.send_slow_query_event(elapsed);
405            }
406        }
407    }
408}
409
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::process_manager::{ProcessId, ProcessManager, Ticket};

    /// Address every test manager reports as its own frontend.
    const SERVER_ADDR: &str = "127.0.0.1:8000";

    /// Builds a manager without a meta client, i.e. without remote frontends.
    fn new_manager() -> Arc<ProcessManager> {
        Arc::new(ProcessManager::new(SERVER_ADDR.to_string(), None))
    }

    /// Registers a query without a slow query timer and returns its ticket.
    fn register(
        manager: &Arc<ProcessManager>,
        catalog: &str,
        schemas: &[&str],
        query: &str,
        client: &str,
        query_id: Option<ProcessId>,
    ) -> Ticket {
        manager.register_query(
            catalog.to_string(),
            schemas.iter().map(|s| s.to_string()).collect(),
            query.to_string(),
            client.to_string(),
            query_id,
            None,
        )
    }

    #[tokio::test]
    async fn test_register_query() {
        let manager = new_manager();
        let ticket = register(
            &manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "",
            None,
        );

        let running = manager.local_processes(None).unwrap();
        assert_eq!(running.len(), 1);
        let only = &running[0];
        assert_eq!(&only.frontend, "127.0.0.1:8000");
        assert_eq!(only.id, ticket.id);
        assert_eq!(&only.query, "SELECT * FROM table");

        drop(ticket);
        assert_eq!(manager.local_processes(None).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_register_query_with_custom_id() {
        let manager = new_manager();
        let custom_id = 12345;

        let ticket = register(
            &manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "client1",
            Some(custom_id),
        );
        assert_eq!(ticket.id, custom_id);

        let running = manager.local_processes(None).unwrap();
        assert_eq!(running.len(), 1);
        assert_eq!(running[0].id, custom_id);
        assert_eq!(&running[0].client, "client1");
    }

    #[tokio::test]
    async fn test_multiple_queries_same_catalog() {
        let manager = new_manager();

        let first = register(
            &manager,
            "public",
            &["schema1"],
            "SELECT * FROM table1",
            "client1",
            None,
        );
        let second = register(
            &manager,
            "public",
            &["schema2"],
            "SELECT * FROM table2",
            "client2",
            None,
        );

        let running = manager.local_processes(Some("public")).unwrap();
        assert_eq!(running.len(), 2);

        // Both registrations must be listed.
        let ids: Vec<u32> = running.iter().map(|p| p.id).collect();
        assert!(ids.contains(&first.id));
        assert!(ids.contains(&second.id));
    }

    #[tokio::test]
    async fn test_multiple_catalogs() {
        let manager = new_manager();

        let _first = register(
            &manager,
            "catalog1",
            &["schema1"],
            "SELECT * FROM table1",
            "client1",
            None,
        );
        let _second = register(
            &manager,
            "catalog2",
            &["schema2"],
            "SELECT * FROM table2",
            "client2",
            None,
        );

        // Listing restricted to a single catalog.
        let in_catalog1 = manager.local_processes(Some("catalog1")).unwrap();
        assert_eq!(in_catalog1.len(), 1);
        assert_eq!(&in_catalog1[0].catalog, "catalog1");

        let in_catalog2 = manager.local_processes(Some("catalog2")).unwrap();
        assert_eq!(in_catalog2.len(), 1);
        assert_eq!(&in_catalog2[0].catalog, "catalog2");

        // Listing across all catalogs.
        let everything = manager.local_processes(None).unwrap();
        assert_eq!(everything.len(), 2);
    }

    #[tokio::test]
    async fn test_deregister_query() {
        let manager = new_manager();

        let ticket = register(
            &manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "client1",
            None,
        );
        assert_eq!(manager.local_processes(None).unwrap().len(), 1);

        manager.deregister_query("public".to_string(), ticket.id);
        assert_eq!(manager.local_processes(None).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_cancellation_handle() {
        let manager = new_manager();

        let ticket = register(
            &manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "client1",
            None,
        );

        let handle = &ticket.cancellation_handle;
        assert!(!handle.is_cancelled());
        handle.cancel();
        assert!(handle.is_cancelled());
    }

    #[tokio::test]
    async fn test_kill_local_process() {
        let manager = new_manager();

        let ticket = register(
            &manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "client1",
            None,
        );
        assert!(!ticket.cancellation_handle.is_cancelled());

        let killed = manager
            .kill_process(SERVER_ADDR.to_string(), "public".to_string(), ticket.id)
            .await
            .unwrap();

        assert!(killed);
        assert_eq!(manager.local_processes(None).unwrap().len(), 0);
    }

    #[tokio::test]
    async fn test_kill_nonexistent_process() {
        let manager = new_manager();
        let killed = manager
            .kill_process(SERVER_ADDR.to_string(), "public".to_string(), 999)
            .await
            .unwrap();
        assert!(!killed);
    }

    #[tokio::test]
    async fn test_kill_process_nonexistent_catalog() {
        let manager = new_manager();
        let killed = manager
            .kill_process(SERVER_ADDR.to_string(), "nonexistent".to_string(), 1)
            .await
            .unwrap();
        assert!(!killed);
    }

    #[tokio::test]
    async fn test_process_info_fields() {
        let manager = new_manager();

        let _ticket = register(
            &manager,
            "test_catalog",
            &["schema1", "schema2"],
            "SELECT COUNT(*) FROM users WHERE age > 18",
            "test_client",
            Some(42),
        );

        let listed = manager.local_processes(None).unwrap();
        assert_eq!(listed.len(), 1);

        let info = &listed[0];
        assert_eq!(info.id, 42);
        assert_eq!(&info.catalog, "test_catalog");
        assert_eq!(info.schemas, vec!["schema1", "schema2"]);
        assert_eq!(&info.query, "SELECT COUNT(*) FROM users WHERE age > 18");
        assert_eq!(&info.client, "test_client");
        assert_eq!(&info.frontend, "127.0.0.1:8000");
        assert!(info.start_timestamp > 0);
    }

    #[tokio::test]
    async fn test_ticket_drop_deregisters_process() {
        let manager = new_manager();

        {
            let _ticket = register(
                &manager,
                "public",
                &["test"],
                "SELECT * FROM table",
                "client1",
                None,
            );

            // While the ticket is alive the process is listed.
            assert_eq!(manager.local_processes(None).unwrap().len(), 1);
        } // The ticket is dropped here.

        // Dropping the ticket deregisters the process.
        assert_eq!(manager.local_processes(None).unwrap().len(), 0);
    }
}
655}