// catalog/process_manager.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::collections::hash_map::Entry;
16use std::collections::HashMap;
17use std::fmt::{Debug, Display, Formatter};
18use std::sync::atomic::{AtomicU32, Ordering};
19use std::sync::{Arc, RwLock};
20use std::time::{Duration, Instant, UNIX_EPOCH};
21
22use api::v1::frontend::{KillProcessRequest, ListProcessRequest, ProcessInfo};
23use common_base::cancellation::CancellationHandle;
24use common_event_recorder::EventRecorderRef;
25use common_frontend::selector::{FrontendSelector, MetaClientSelector};
26use common_frontend::slow_query_event::SlowQueryEvent;
27use common_telemetry::logging::SlowQueriesRecordType;
28use common_telemetry::{debug, info, slow, warn};
29use common_time::util::current_time_millis;
30use meta_client::MetaClientRef;
31use promql_parser::parser::EvalStmt;
32use rand::random;
33use snafu::{ensure, OptionExt, ResultExt};
34use sql::statements::statement::Statement;
35
36use crate::error;
37use crate::metrics::{PROCESS_KILL_COUNT, PROCESS_LIST_COUNT};
38
/// Identifier of a single query process; unique per frontend instance.
pub type ProcessId = u32;
/// Shared, thread-safe reference to a [ProcessManager].
pub type ProcessManagerRef = Arc<ProcessManager>;
41
/// Query process manager.
///
/// Tracks all in-flight queries registered on this frontend (grouped by
/// catalog) and, when a meta client is available, can reach other frontends
/// to list or kill their processes.
pub struct ProcessManager {
    /// Local frontend server address.
    server_addr: String,
    /// Next process id for local queries.
    next_id: AtomicU32,
    /// Running process per catalog.
    catalogs: RwLock<HashMap<String, HashMap<ProcessId, CancellableProcess>>>,
    /// Frontend selector to locate frontend nodes; `None` when no meta client
    /// was provided (standalone mode), in which case only local operations work.
    frontend_selector: Option<MetaClientSelector>,
}
53
/// Represents a parsed query statement, functionally equivalent to [query::parser::QueryStatement].
/// This enum is defined here to avoid cyclic dependencies with the query parser module.
#[derive(Debug, Clone)]
pub enum QueryStatement {
    /// A parsed SQL statement.
    Sql(Statement),
    /// A parsed PromQL evaluation statement.
    Promql(EvalStmt),
}
61
62impl Display for QueryStatement {
63    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
64        match self {
65            QueryStatement::Sql(stmt) => write!(f, "{}", stmt),
66            QueryStatement::Promql(eval_stmt) => write!(f, "{}", eval_stmt),
67        }
68    }
69}
70
71impl ProcessManager {
72    /// Create a [ProcessManager] instance with server address and kv client.
73    pub fn new(server_addr: String, meta_client: Option<MetaClientRef>) -> Self {
74        let frontend_selector = meta_client.map(MetaClientSelector::new);
75        Self {
76            server_addr,
77            next_id: Default::default(),
78            catalogs: Default::default(),
79            frontend_selector,
80        }
81    }
82}
83
84impl ProcessManager {
85    /// Registers a submitted query. Use the provided id if present.
86    #[must_use]
87    pub fn register_query(
88        self: &Arc<Self>,
89        catalog: String,
90        schemas: Vec<String>,
91        query: String,
92        client: String,
93        query_id: Option<ProcessId>,
94        _slow_query_timer: Option<SlowQueryTimer>,
95    ) -> Ticket {
96        let id = query_id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed));
97        let process = ProcessInfo {
98            id,
99            catalog: catalog.clone(),
100            schemas,
101            query,
102            start_timestamp: current_time_millis(),
103            client,
104            frontend: self.server_addr.clone(),
105        };
106        let cancellation_handle = Arc::new(CancellationHandle::default());
107        let cancellable_process = CancellableProcess::new(cancellation_handle.clone(), process);
108
109        self.catalogs
110            .write()
111            .unwrap()
112            .entry(catalog.clone())
113            .or_default()
114            .insert(id, cancellable_process);
115
116        Ticket {
117            catalog,
118            manager: self.clone(),
119            id,
120            cancellation_handle,
121            _slow_query_timer,
122        }
123    }
124
125    /// Generates the next process id.
126    pub fn next_id(&self) -> u32 {
127        self.next_id.fetch_add(1, Ordering::Relaxed)
128    }
129
130    /// De-register a query from process list.
131    pub fn deregister_query(&self, catalog: String, id: ProcessId) {
132        if let Entry::Occupied(mut o) = self.catalogs.write().unwrap().entry(catalog) {
133            let process = o.get_mut().remove(&id);
134            debug!("Deregister process: {:?}", process);
135            if o.get().is_empty() {
136                o.remove();
137            }
138        }
139    }
140
141    /// List local running processes in given catalog.
142    pub fn local_processes(&self, catalog: Option<&str>) -> error::Result<Vec<ProcessInfo>> {
143        let catalogs = self.catalogs.read().unwrap();
144        let result = if let Some(catalog) = catalog {
145            if let Some(catalogs) = catalogs.get(catalog) {
146                catalogs.values().map(|p| p.process.clone()).collect()
147            } else {
148                vec![]
149            }
150        } else {
151            catalogs
152                .values()
153                .flat_map(|v| v.values().map(|p| p.process.clone()))
154                .collect()
155        };
156        Ok(result)
157    }
158
159    pub async fn list_all_processes(
160        &self,
161        catalog: Option<&str>,
162    ) -> error::Result<Vec<ProcessInfo>> {
163        let mut processes = vec![];
164        if let Some(remote_frontend_selector) = self.frontend_selector.as_ref() {
165            let frontends = remote_frontend_selector
166                .select(|node| node.peer.addr != self.server_addr)
167                .await
168                .context(error::InvokeFrontendSnafu)?;
169            for mut f in frontends {
170                let result = f
171                    .list_process(ListProcessRequest {
172                        catalog: catalog.unwrap_or_default().to_string(),
173                    })
174                    .await
175                    .context(error::InvokeFrontendSnafu);
176                match result {
177                    Ok(resp) => {
178                        processes.extend(resp.processes);
179                    }
180                    Err(e) => {
181                        warn!(e; "Skipping failing node: {:?}", f)
182                    }
183                }
184            }
185        }
186        processes.extend(self.local_processes(catalog)?);
187        Ok(processes)
188    }
189
190    /// Kills query with provided catalog and id.
191    pub async fn kill_process(
192        &self,
193        server_addr: String,
194        catalog: String,
195        id: ProcessId,
196    ) -> error::Result<bool> {
197        if server_addr == self.server_addr {
198            self.kill_local_process(catalog, id).await
199        } else {
200            let mut nodes = self
201                .frontend_selector
202                .as_ref()
203                .context(error::MetaClientMissingSnafu)?
204                .select(|node| node.peer.addr == server_addr)
205                .await
206                .context(error::InvokeFrontendSnafu)?;
207            ensure!(
208                !nodes.is_empty(),
209                error::FrontendNotFoundSnafu { addr: server_addr }
210            );
211
212            let request = KillProcessRequest {
213                server_addr,
214                catalog,
215                process_id: id,
216            };
217            nodes[0]
218                .kill_process(request)
219                .await
220                .context(error::InvokeFrontendSnafu)?;
221            Ok(true)
222        }
223    }
224
225    /// Kills local query with provided catalog and id.
226    pub async fn kill_local_process(&self, catalog: String, id: ProcessId) -> error::Result<bool> {
227        if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) {
228            if let Some(process) = catalogs.remove(&id) {
229                process.handle.cancel();
230                info!(
231                    "Killed process, catalog: {}, id: {:?}",
232                    process.process.catalog, process.process.id
233                );
234                PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc();
235                Ok(true)
236            } else {
237                debug!("Failed to kill process, id not found: {}", id);
238                Ok(false)
239            }
240        } else {
241            debug!("Failed to kill process, catalog not found: {}", catalog);
242            Ok(false)
243        }
244    }
245}
246
/// Guard handle for a registered query.
///
/// Holding a `Ticket` keeps the query visible in the process list; dropping
/// it deregisters the query from its [ProcessManager].
pub struct Ticket {
    /// Catalog the query was registered under.
    pub(crate) catalog: String,
    /// Owning manager; used for deregistration on drop.
    pub(crate) manager: ProcessManagerRef,
    /// Process id assigned at registration time.
    pub(crate) id: ProcessId,
    /// Handle used to signal/observe cancellation of this query.
    pub cancellation_handle: Arc<CancellationHandle>,

    // Keep the handle of the slow query timer to ensure it will trigger the event recording when dropped.
    _slow_query_timer: Option<SlowQueryTimer>,
}
256
257impl Drop for Ticket {
258    fn drop(&mut self) {
259        self.manager
260            .deregister_query(std::mem::take(&mut self.catalog), self.id);
261    }
262}
263
/// A running process paired with the handle used to cancel it.
struct CancellableProcess {
    /// Cancellation handle shared with the query's [Ticket].
    handle: Arc<CancellationHandle>,
    /// Metadata describing the running query.
    process: ProcessInfo,
}
268
269impl Drop for CancellableProcess {
270    fn drop(&mut self) {
271        PROCESS_LIST_COUNT
272            .with_label_values(&[&self.process.catalog])
273            .dec();
274    }
275}
276
277impl CancellableProcess {
278    fn new(handle: Arc<CancellationHandle>, process: ProcessInfo) -> Self {
279        PROCESS_LIST_COUNT
280            .with_label_values(&[&process.catalog])
281            .inc();
282        Self { handle, process }
283    }
284}
285
286impl Debug for CancellableProcess {
287    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
288        f.debug_struct("CancellableProcess")
289            .field("cancelled", &self.handle.is_cancelled())
290            .field("process", &self.process)
291            .finish()
292    }
293}
294
/// SlowQueryTimer is used to log slow query when it's dropped.
/// In drop(), it will check if the query is slow and send the slow query event to the handler.
pub struct SlowQueryTimer {
    /// Instant the timer (and therefore the query) started.
    start: Instant,
    /// Statement being timed; used to build the slow-query event payload.
    stmt: QueryStatement,
    /// Queries running longer than this are considered slow.
    threshold: Duration,
    /// Fraction of slow queries to record; values >= 1.0 record all of them.
    sample_ratio: f64,
    /// Destination for slow-query events: system table or log file.
    record_type: SlowQueriesRecordType,
    /// Recorder used when events are persisted to the system table.
    recorder: EventRecorderRef,
}
305
306impl SlowQueryTimer {
307    pub fn new(
308        stmt: QueryStatement,
309        threshold: Duration,
310        sample_ratio: f64,
311        record_type: SlowQueriesRecordType,
312        recorder: EventRecorderRef,
313    ) -> Self {
314        Self {
315            start: Instant::now(),
316            stmt,
317            threshold,
318            sample_ratio,
319            record_type,
320            recorder,
321        }
322    }
323}
324
325impl SlowQueryTimer {
326    fn send_slow_query_event(&self, elapsed: Duration) {
327        let mut slow_query_event = SlowQueryEvent {
328            cost: elapsed.as_millis() as u64,
329            threshold: self.threshold.as_millis() as u64,
330            query: "".to_string(),
331
332            // The following fields are only used for PromQL queries.
333            is_promql: false,
334            promql_range: None,
335            promql_step: None,
336            promql_start: None,
337            promql_end: None,
338        };
339
340        match &self.stmt {
341            QueryStatement::Promql(stmt) => {
342                slow_query_event.is_promql = true;
343                slow_query_event.query = stmt.expr.to_string();
344                slow_query_event.promql_step = Some(stmt.interval.as_millis() as u64);
345
346                let start = stmt
347                    .start
348                    .duration_since(UNIX_EPOCH)
349                    .unwrap_or_default()
350                    .as_millis() as i64;
351
352                let end = stmt
353                    .end
354                    .duration_since(UNIX_EPOCH)
355                    .unwrap_or_default()
356                    .as_millis() as i64;
357
358                slow_query_event.promql_range = Some((end - start) as u64);
359                slow_query_event.promql_start = Some(start);
360                slow_query_event.promql_end = Some(end);
361            }
362            QueryStatement::Sql(stmt) => {
363                slow_query_event.query = stmt.to_string();
364            }
365        }
366
367        match self.record_type {
368            // Send the slow query event to the event recorder to persist it as the system table.
369            SlowQueriesRecordType::SystemTable => {
370                self.recorder.record(Box::new(slow_query_event));
371            }
372            // Record the slow query in a specific logs file.
373            SlowQueriesRecordType::Log => {
374                slow!(
375                    cost = slow_query_event.cost,
376                    threshold = slow_query_event.threshold,
377                    query = slow_query_event.query,
378                    is_promql = slow_query_event.is_promql,
379                    promql_range = slow_query_event.promql_range,
380                    promql_step = slow_query_event.promql_step,
381                    promql_start = slow_query_event.promql_start,
382                    promql_end = slow_query_event.promql_end,
383                );
384            }
385        }
386    }
387}
388
389impl Drop for SlowQueryTimer {
390    fn drop(&mut self) {
391        // Calculate the elaspsed duration since the timer is created.
392        let elapsed = self.start.elapsed();
393        if elapsed > self.threshold {
394            // Only capture a portion of slow queries based on sample_ratio.
395            // Generate a random number in [0, 1) and compare it with sample_ratio.
396            if self.sample_ratio >= 1.0 || random::<f64>() <= self.sample_ratio {
397                self.send_slow_query_event(elapsed);
398            }
399        }
400    }
401}
402
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::process_manager::ProcessManager;

    // Registering a query makes it visible via `local_processes`, and
    // dropping the ticket removes it again.
    #[tokio::test]
    async fn test_register_query() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));
        let ticket = process_manager.clone().register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "".to_string(),
            None,
            None,
        );

        let running_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(running_processes.len(), 1);
        assert_eq!(&running_processes[0].frontend, "127.0.0.1:8000");
        assert_eq!(running_processes[0].id, ticket.id);
        assert_eq!(&running_processes[0].query, "SELECT * FROM table");

        drop(ticket);
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    // A caller-supplied query id is used verbatim instead of an allocated one.
    #[tokio::test]
    async fn test_register_query_with_custom_id() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));
        let custom_id = 12345;

        let ticket = process_manager.clone().register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            Some(custom_id),
            None,
        );

        assert_eq!(ticket.id, custom_id);

        let running_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(running_processes.len(), 1);
        assert_eq!(running_processes[0].id, custom_id);
        assert_eq!(&running_processes[0].client, "client1");
    }

    // Two queries in the same catalog coexist and both appear in a
    // catalog-scoped listing.
    #[tokio::test]
    async fn test_multiple_queries_same_catalog() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        let ticket1 = process_manager.clone().register_query(
            "public".to_string(),
            vec!["schema1".to_string()],
            "SELECT * FROM table1".to_string(),
            "client1".to_string(),
            None,
            None,
        );

        let ticket2 = process_manager.clone().register_query(
            "public".to_string(),
            vec!["schema2".to_string()],
            "SELECT * FROM table2".to_string(),
            "client2".to_string(),
            None,
            None,
        );

        let running_processes = process_manager.local_processes(Some("public")).unwrap();
        assert_eq!(running_processes.len(), 2);

        // Verify both processes are present
        let ids: Vec<u32> = running_processes.iter().map(|p| p.id).collect();
        assert!(ids.contains(&ticket1.id));
        assert!(ids.contains(&ticket2.id));
    }

    // Queries registered under different catalogs are isolated per catalog
    // but all show up in an unscoped (None) listing.
    #[tokio::test]
    async fn test_multiple_catalogs() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        let _ticket1 = process_manager.clone().register_query(
            "catalog1".to_string(),
            vec!["schema1".to_string()],
            "SELECT * FROM table1".to_string(),
            "client1".to_string(),
            None,
            None,
        );

        let _ticket2 = process_manager.clone().register_query(
            "catalog2".to_string(),
            vec!["schema2".to_string()],
            "SELECT * FROM table2".to_string(),
            "client2".to_string(),
            None,
            None,
        );

        // Test listing processes for specific catalog
        let catalog1_processes = process_manager.local_processes(Some("catalog1")).unwrap();
        assert_eq!(catalog1_processes.len(), 1);
        assert_eq!(&catalog1_processes[0].catalog, "catalog1");

        let catalog2_processes = process_manager.local_processes(Some("catalog2")).unwrap();
        assert_eq!(catalog2_processes.len(), 1);
        assert_eq!(&catalog2_processes[0].catalog, "catalog2");

        // Test listing all processes
        let all_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(all_processes.len(), 2);
    }

    // Explicit deregistration removes the process even while the ticket is
    // still alive.
    #[tokio::test]
    async fn test_deregister_query() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        let ticket = process_manager.clone().register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            None,
            None,
        );
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 1);
        process_manager.deregister_query("public".to_string(), ticket.id);
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    // The cancellation handle starts un-cancelled and flips when cancelled.
    #[tokio::test]
    async fn test_cancellation_handle() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        let ticket = process_manager.clone().register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            None,
            None,
        );

        assert!(!ticket.cancellation_handle.is_cancelled());
        ticket.cancellation_handle.cancel();
        assert!(ticket.cancellation_handle.is_cancelled());
    }

    // Killing via the local server address removes the process and reports
    // success.
    #[tokio::test]
    async fn test_kill_local_process() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        let ticket = process_manager.clone().register_query(
            "public".to_string(),
            vec!["test".to_string()],
            "SELECT * FROM table".to_string(),
            "client1".to_string(),
            None,
            None,
        );
        assert!(!ticket.cancellation_handle.is_cancelled());
        let killed = process_manager
            .kill_process(
                "127.0.0.1:8000".to_string(),
                "public".to_string(),
                ticket.id,
            )
            .await
            .unwrap();

        assert!(killed);
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }

    // Killing an unknown id in an existing-address catalog reports `false`
    // instead of erroring.
    #[tokio::test]
    async fn test_kill_nonexistent_process() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));
        let killed = process_manager
            .kill_process("127.0.0.1:8000".to_string(), "public".to_string(), 999)
            .await
            .unwrap();
        assert!(!killed);
    }

    // Killing in an unknown catalog likewise reports `false`.
    #[tokio::test]
    async fn test_kill_process_nonexistent_catalog() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));
        let killed = process_manager
            .kill_process("127.0.0.1:8000".to_string(), "nonexistent".to_string(), 1)
            .await
            .unwrap();
        assert!(!killed);
    }

    // All ProcessInfo fields are populated from the registration arguments.
    #[tokio::test]
    async fn test_process_info_fields() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        let _ticket = process_manager.clone().register_query(
            "test_catalog".to_string(),
            vec!["schema1".to_string(), "schema2".to_string()],
            "SELECT COUNT(*) FROM users WHERE age > 18".to_string(),
            "test_client".to_string(),
            Some(42),
            None,
        );

        let processes = process_manager.local_processes(None).unwrap();
        assert_eq!(processes.len(), 1);

        let process = &processes[0];
        assert_eq!(process.id, 42);
        assert_eq!(&process.catalog, "test_catalog");
        assert_eq!(process.schemas, vec!["schema1", "schema2"]);
        assert_eq!(&process.query, "SELECT COUNT(*) FROM users WHERE age > 18");
        assert_eq!(&process.client, "test_client");
        assert_eq!(&process.frontend, "127.0.0.1:8000");
        assert!(process.start_timestamp > 0);
    }

    // Ticket's Drop impl performs the deregistration automatically.
    #[tokio::test]
    async fn test_ticket_drop_deregisters_process() {
        let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None));

        {
            let _ticket = process_manager.clone().register_query(
                "public".to_string(),
                vec!["test".to_string()],
                "SELECT * FROM table".to_string(),
                "client1".to_string(),
                None,
                None,
            );

            // Process should be registered
            assert_eq!(process_manager.local_processes(None).unwrap().len(), 1);
        } // ticket goes out of scope here

        // Process should be automatically deregistered
        assert_eq!(process_manager.local_processes(None).unwrap().len(), 0);
    }
}