Skip to main content

catalog/
process_manager.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::fmt::{Debug, Display, Formatter};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

use api::v1::frontend::{KillProcessRequest, ListProcessRequest, ProcessInfo};
use common_base::cancellation::CancellationHandle;
use common_event_recorder::EventRecorderRef;
use common_frontend::selector::{FrontendSelector, MetaClientSelector};
use common_frontend::slow_query_event::SlowQueryEvent;
use common_telemetry::logging::SlowQueriesRecordType;
use common_telemetry::{debug, info, slow, warn};
use common_time::util::current_time_millis;
use meta_client::MetaClientRef;
use promql_parser::parser::EvalStmt;
use rand::random;
use snafu::{OptionExt, ResultExt, ensure};
use sql::statements::statement::Statement;

use crate::error;
use crate::metrics::{PROCESS_KILL_COUNT, PROCESS_LIST_COUNT};
38
39pub type ProcessId = u32;
40pub type ProcessManagerRef = Arc<ProcessManager>;
41
42/// Query process manager.
43pub struct ProcessManager {
44    /// Local frontend server address,
45    server_addr: String,
46    /// Next process id for local queries.
47    next_id: AtomicU32,
48    /// Running process per catalog.
49    catalogs: RwLock<HashMap<String, HashMap<ProcessId, CancellableProcess>>>,
50    /// Frontend selector to locate frontend nodes.
51    frontend_selector: Option<MetaClientSelector>,
52}
53
54/// Represents a parsed query statement, functionally equivalent to [query::parser::QueryStatement].
55/// This enum is defined here to avoid cyclic dependencies with the query parser module.
56#[derive(Debug, Clone)]
57pub enum QueryStatement {
58    Sql(Statement),
59    // The optional string is the alias of the PromQL query.
60    Promql(EvalStmt, Option<String>),
61    /// Logical plan with original query string
62    Plan(String),
63}
64
65impl Display for QueryStatement {
66    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
67        match self {
68            QueryStatement::Sql(stmt) => write!(f, "{}", stmt),
69            QueryStatement::Promql(eval_stmt, alias) => {
70                if let Some(alias) = alias {
71                    write!(f, "{} AS {}", eval_stmt, alias)
72                } else {
73                    write!(f, "{}", eval_stmt)
74                }
75            }
76            QueryStatement::Plan(query) => write!(f, "{}", query),
77        }
78    }
79}
80
81impl ProcessManager {
82    /// Create a [ProcessManager] instance with server address and kv client.
83    pub fn new(server_addr: String, meta_client: Option<MetaClientRef>) -> Self {
84        let frontend_selector = meta_client.map(MetaClientSelector::new);
85        Self {
86            server_addr,
87            next_id: Default::default(),
88            catalogs: Default::default(),
89            frontend_selector,
90        }
91    }
92}
93
94impl ProcessManager {
95    /// Registers a submitted query. Use the provided id if present.
96    #[must_use]
97    pub fn register_query(
98        self: &Arc<Self>,
99        catalog: String,
100        schemas: Vec<String>,
101        query: String,
102        client: String,
103        query_id: Option<ProcessId>,
104        _slow_query_timer: Option<SlowQueryTimer>,
105    ) -> Ticket {
106        let id = query_id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed));
107        let process = ProcessInfo {
108            id,
109            catalog: catalog.clone(),
110            schemas,
111            query,
112            start_timestamp: current_time_millis(),
113            client,
114            frontend: self.server_addr.clone(),
115        };
116        let cancellation_handle = Arc::new(CancellationHandle::default());
117        let cancellable_process = CancellableProcess::new(cancellation_handle.clone(), process);
118
119        self.catalogs
120            .write()
121            .unwrap()
122            .entry(catalog.clone())
123            .or_default()
124            .insert(id, cancellable_process);
125
126        Ticket {
127            catalog,
128            manager: self.clone(),
129            id,
130            cancellation_handle,
131            _slow_query_timer,
132        }
133    }
134
135    /// Generates the next process id.
136    pub fn next_id(&self) -> u32 {
137        self.next_id.fetch_add(1, Ordering::Relaxed)
138    }
139
140    /// De-register a query from process list.
141    pub fn deregister_query(&self, catalog: String, id: ProcessId) {
142        if let Entry::Occupied(mut o) = self.catalogs.write().unwrap().entry(catalog) {
143            let process = o.get_mut().remove(&id);
144            debug!("Deregister process: {:?}", process);
145            if o.get().is_empty() {
146                o.remove();
147            }
148        }
149    }
150
151    /// List local running processes in given catalog.
152    pub fn local_processes(&self, catalog: Option<&str>) -> error::Result<Vec<ProcessInfo>> {
153        let catalogs = self.catalogs.read().unwrap();
154        let result = if let Some(catalog) = catalog {
155            if let Some(catalogs) = catalogs.get(catalog) {
156                catalogs.values().map(|p| p.process.clone()).collect()
157            } else {
158                vec![]
159            }
160        } else {
161            catalogs
162                .values()
163                .flat_map(|v| v.values().map(|p| p.process.clone()))
164                .collect()
165        };
166        Ok(result)
167    }
168
169    pub async fn list_all_processes(
170        &self,
171        catalog: Option<&str>,
172    ) -> error::Result<Vec<ProcessInfo>> {
173        let mut processes = vec![];
174        if let Some(remote_frontend_selector) = self.frontend_selector.as_ref() {
175            let frontends = remote_frontend_selector
176                .select(|peer| peer.addr != self.server_addr)
177                .await
178                .context(error::InvokeFrontendSnafu)?;
179            for mut f in frontends {
180                let result = f
181                    .list_process(ListProcessRequest {
182                        catalog: catalog.unwrap_or_default().to_string(),
183                    })
184                    .await
185                    .context(error::InvokeFrontendSnafu);
186                match result {
187                    Ok(resp) => {
188                        processes.extend(resp.processes);
189                    }
190                    Err(e) => {
191                        warn!(e; "Skipping failing node: {:?}", f)
192                    }
193                }
194            }
195        }
196        processes.extend(self.local_processes(catalog)?);
197        Ok(processes)
198    }
199
200    /// Kills query with provided catalog and id.
201    pub async fn kill_process(
202        &self,
203        server_addr: String,
204        catalog: String,
205        id: ProcessId,
206    ) -> error::Result<bool> {
207        if server_addr == self.server_addr {
208            self.kill_local_process(catalog, id).await
209        } else {
210            let mut nodes = self
211                .frontend_selector
212                .as_ref()
213                .context(error::MetaClientMissingSnafu)?
214                .select(|peer| peer.addr == server_addr)
215                .await
216                .context(error::InvokeFrontendSnafu)?;
217            ensure!(
218                !nodes.is_empty(),
219                error::FrontendNotFoundSnafu { addr: server_addr }
220            );
221
222            let request = KillProcessRequest {
223                server_addr,
224                catalog,
225                process_id: id,
226            };
227            nodes[0]
228                .kill_process(request)
229                .await
230                .context(error::InvokeFrontendSnafu)?;
231            Ok(true)
232        }
233    }
234
235    /// Kills local query with provided catalog and id.
236    pub async fn kill_local_process(&self, catalog: String, id: ProcessId) -> error::Result<bool> {
237        if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) {
238            if let Some(process) = catalogs.remove(&id) {
239                process.handle.cancel();
240                info!(
241                    "Killed process, catalog: {}, id: {:?}",
242                    process.process.catalog, process.process.id
243                );
244                PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc();
245                Ok(true)
246            } else {
247                debug!("Failed to kill process, id not found: {}", id);
248                Ok(false)
249            }
250        } else {
251            debug!("Failed to kill process, catalog not found: {}", catalog);
252            Ok(false)
253        }
254    }
255}
256
257pub struct Ticket {
258    pub(crate) catalog: String,
259    pub(crate) manager: ProcessManagerRef,
260    pub(crate) id: ProcessId,
261    pub cancellation_handle: Arc<CancellationHandle>,
262
263    // Keep the handle of the slow query timer to ensure it will trigger the event recording when dropped.
264    _slow_query_timer: Option<SlowQueryTimer>,
265}
266
267impl Drop for Ticket {
268    fn drop(&mut self) {
269        self.manager
270            .deregister_query(std::mem::take(&mut self.catalog), self.id);
271    }
272}
273
274struct CancellableProcess {
275    handle: Arc<CancellationHandle>,
276    process: ProcessInfo,
277}
278
279impl Drop for CancellableProcess {
280    fn drop(&mut self) {
281        PROCESS_LIST_COUNT
282            .with_label_values(&[&self.process.catalog])
283            .dec();
284    }
285}
286
287impl CancellableProcess {
288    fn new(handle: Arc<CancellationHandle>, process: ProcessInfo) -> Self {
289        PROCESS_LIST_COUNT
290            .with_label_values(&[&process.catalog])
291            .inc();
292        Self { handle, process }
293    }
294}
295
296impl Debug for CancellableProcess {
297    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
298        f.debug_struct("CancellableProcess")
299            .field("cancelled", &self.handle.is_cancelled())
300            .field("process", &self.process)
301            .finish()
302    }
303}
304
305/// SlowQueryTimer is used to log slow query when it's dropped.
306/// In drop(), it will check if the query is slow and send the slow query event to the handler.
307pub struct SlowQueryTimer {
308    start: Instant,
309    stmt: QueryStatement,
310    threshold: Duration,
311    sample_ratio: f64,
312    record_type: SlowQueriesRecordType,
313    recorder: EventRecorderRef,
314}
315
316impl SlowQueryTimer {
317    pub fn new(
318        stmt: QueryStatement,
319        threshold: Duration,
320        sample_ratio: f64,
321        record_type: SlowQueriesRecordType,
322        recorder: EventRecorderRef,
323    ) -> Self {
324        Self {
325            start: Instant::now(),
326            stmt,
327            threshold,
328            sample_ratio,
329            record_type,
330            recorder,
331        }
332    }
333}
334
335impl SlowQueryTimer {
336    fn send_slow_query_event(&self, elapsed: Duration) {
337        let mut slow_query_event = SlowQueryEvent {
338            cost: elapsed.as_millis() as u64,
339            threshold: self.threshold.as_millis() as u64,
340            query: "".to_string(),
341
342            // The following fields are only used for PromQL queries.
343            is_promql: false,
344            promql_range: None,
345            promql_step: None,
346            promql_start: None,
347            promql_end: None,
348        };
349
350        match &self.stmt {
351            QueryStatement::Promql(stmt, _alias) => {
352                slow_query_event.is_promql = true;
353                slow_query_event.query = self.stmt.to_string();
354                slow_query_event.promql_step = Some(stmt.interval.as_millis() as u64);
355
356                let start = stmt
357                    .start
358                    .duration_since(UNIX_EPOCH)
359                    .unwrap_or_default()
360                    .as_millis() as i64;
361
362                let end = stmt
363                    .end
364                    .duration_since(UNIX_EPOCH)
365                    .unwrap_or_default()
366                    .as_millis() as i64;
367
368                slow_query_event.promql_range = Some((end - start) as u64);
369                slow_query_event.promql_start = Some(start);
370                slow_query_event.promql_end = Some(end);
371            }
372            QueryStatement::Sql(stmt) => {
373                slow_query_event.query = stmt.to_string();
374            }
375            QueryStatement::Plan(query) => {
376                slow_query_event.query = query.clone();
377            }
378        }
379
380        match self.record_type {
381            // Send the slow query event to the event recorder to persist it as the system table.
382            SlowQueriesRecordType::SystemTable => {
383                self.recorder.record(Box::new(slow_query_event));
384            }
385            // Record the slow query in a specific logs file.
386            SlowQueriesRecordType::Log => {
387                slow!(
388                    cost = slow_query_event.cost,
389                    threshold = slow_query_event.threshold,
390                    query = slow_query_event.query,
391                    is_promql = slow_query_event.is_promql,
392                    promql_range = slow_query_event.promql_range,
393                    promql_step = slow_query_event.promql_step,
394                    promql_start = slow_query_event.promql_start,
395                    promql_end = slow_query_event.promql_end,
396                );
397            }
398        }
399    }
400}
401
402impl Drop for SlowQueryTimer {
403    fn drop(&mut self) {
404        // Calculate the elapsed duration since the timer is created.
405        let elapsed = self.start.elapsed();
406        if elapsed > self.threshold {
407            // Only capture a portion of slow queries based on sample_ratio.
408            // Generate a random number in [0, 1) and compare it with sample_ratio.
409            if self.sample_ratio >= 1.0 || random::<f64>() <= self.sample_ratio {
410                self.send_slow_query_event(elapsed);
411            }
412        }
413    }
414}
415
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::process_manager::{ProcessManager, Ticket};

    /// Address every test manager is created with.
    const SERVER_ADDR: &str = "127.0.0.1:8000";

    /// Creates a manager without a meta client, i.e. one that only knows local processes.
    fn new_manager() -> Arc<ProcessManager> {
        Arc::new(ProcessManager::new(SERVER_ADDR.to_string(), None))
    }

    /// Registers a query without a slow query timer and returns its ticket.
    fn register(
        manager: &Arc<ProcessManager>,
        catalog: &str,
        schemas: &[&str],
        query: &str,
        client: &str,
        id: Option<u32>,
    ) -> Ticket {
        manager.register_query(
            catalog.to_string(),
            schemas.iter().map(|s| s.to_string()).collect(),
            query.to_string(),
            client.to_string(),
            id,
            None,
        )
    }

    /// Number of processes currently registered on `manager`, across all catalogs.
    fn running_count(manager: &ProcessManager) -> usize {
        manager.local_processes(None).unwrap().len()
    }

    #[tokio::test]
    async fn test_register_query() {
        let process_manager = new_manager();
        let ticket = register(
            &process_manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "",
            None,
        );

        let running_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(running_processes.len(), 1);
        assert_eq!(&running_processes[0].frontend, "127.0.0.1:8000");
        assert_eq!(running_processes[0].id, ticket.id);
        assert_eq!(&running_processes[0].query, "SELECT * FROM table");

        drop(ticket);
        assert_eq!(running_count(&process_manager), 0);
    }

    #[tokio::test]
    async fn test_register_query_with_custom_id() {
        let process_manager = new_manager();
        let custom_id = 12345;

        let ticket = register(
            &process_manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "client1",
            Some(custom_id),
        );

        assert_eq!(ticket.id, custom_id);

        let running_processes = process_manager.local_processes(None).unwrap();
        assert_eq!(running_processes.len(), 1);
        assert_eq!(running_processes[0].id, custom_id);
        assert_eq!(&running_processes[0].client, "client1");
    }

    #[tokio::test]
    async fn test_multiple_queries_same_catalog() {
        let process_manager = new_manager();

        let ticket1 = register(
            &process_manager,
            "public",
            &["schema1"],
            "SELECT * FROM table1",
            "client1",
            None,
        );
        let ticket2 = register(
            &process_manager,
            "public",
            &["schema2"],
            "SELECT * FROM table2",
            "client2",
            None,
        );

        let running_processes = process_manager.local_processes(Some("public")).unwrap();
        assert_eq!(running_processes.len(), 2);

        // Verify both processes are present
        let ids: Vec<u32> = running_processes.iter().map(|p| p.id).collect();
        assert!(ids.contains(&ticket1.id));
        assert!(ids.contains(&ticket2.id));
    }

    #[tokio::test]
    async fn test_multiple_catalogs() {
        let process_manager = new_manager();

        let _ticket1 = register(
            &process_manager,
            "catalog1",
            &["schema1"],
            "SELECT * FROM table1",
            "client1",
            None,
        );
        let _ticket2 = register(
            &process_manager,
            "catalog2",
            &["schema2"],
            "SELECT * FROM table2",
            "client2",
            None,
        );

        // Test listing processes for specific catalog
        let catalog1_processes = process_manager.local_processes(Some("catalog1")).unwrap();
        assert_eq!(catalog1_processes.len(), 1);
        assert_eq!(&catalog1_processes[0].catalog, "catalog1");

        let catalog2_processes = process_manager.local_processes(Some("catalog2")).unwrap();
        assert_eq!(catalog2_processes.len(), 1);
        assert_eq!(&catalog2_processes[0].catalog, "catalog2");

        // Test listing all processes
        assert_eq!(running_count(&process_manager), 2);
    }

    #[tokio::test]
    async fn test_deregister_query() {
        let process_manager = new_manager();

        let ticket = register(
            &process_manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "client1",
            None,
        );
        assert_eq!(running_count(&process_manager), 1);
        process_manager.deregister_query("public".to_string(), ticket.id);
        assert_eq!(running_count(&process_manager), 0);
    }

    #[tokio::test]
    async fn test_cancellation_handle() {
        let process_manager = new_manager();

        let ticket = register(
            &process_manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "client1",
            None,
        );

        assert!(!ticket.cancellation_handle.is_cancelled());
        ticket.cancellation_handle.cancel();
        assert!(ticket.cancellation_handle.is_cancelled());
    }

    #[tokio::test]
    async fn test_kill_local_process() {
        let process_manager = new_manager();

        let ticket = register(
            &process_manager,
            "public",
            &["test"],
            "SELECT * FROM table",
            "client1",
            None,
        );
        assert!(!ticket.cancellation_handle.is_cancelled());
        let killed = process_manager
            .kill_process(SERVER_ADDR.to_string(), "public".to_string(), ticket.id)
            .await
            .unwrap();

        assert!(killed);
        // Killing must trigger the handle shared with the ticket, not just drop the entry.
        assert!(ticket.cancellation_handle.is_cancelled());
        assert_eq!(running_count(&process_manager), 0);
    }

    #[tokio::test]
    async fn test_kill_nonexistent_process() {
        let process_manager = new_manager();
        let killed = process_manager
            .kill_process(SERVER_ADDR.to_string(), "public".to_string(), 999)
            .await
            .unwrap();
        assert!(!killed);
    }

    #[tokio::test]
    async fn test_kill_process_nonexistent_catalog() {
        let process_manager = new_manager();
        let killed = process_manager
            .kill_process(SERVER_ADDR.to_string(), "nonexistent".to_string(), 1)
            .await
            .unwrap();
        assert!(!killed);
    }

    #[tokio::test]
    async fn test_process_info_fields() {
        let process_manager = new_manager();

        let _ticket = register(
            &process_manager,
            "test_catalog",
            &["schema1", "schema2"],
            "SELECT COUNT(*) FROM users WHERE age > 18",
            "test_client",
            Some(42),
        );

        let processes = process_manager.local_processes(None).unwrap();
        assert_eq!(processes.len(), 1);

        let process = &processes[0];
        assert_eq!(process.id, 42);
        assert_eq!(&process.catalog, "test_catalog");
        assert_eq!(process.schemas, vec!["schema1", "schema2"]);
        assert_eq!(&process.query, "SELECT COUNT(*) FROM users WHERE age > 18");
        assert_eq!(&process.client, "test_client");
        assert_eq!(&process.frontend, "127.0.0.1:8000");
        assert!(process.start_timestamp > 0);
    }

    #[tokio::test]
    async fn test_ticket_drop_deregisters_process() {
        let process_manager = new_manager();

        {
            let _ticket = register(
                &process_manager,
                "public",
                &["test"],
                "SELECT * FROM table",
                "client1",
                None,
            );

            // Process should be registered
            assert_eq!(running_count(&process_manager), 1);
        } // ticket goes out of scope here

        // Process should be automatically deregistered
        assert_eq!(running_count(&process_manager), 0);
    }
}
661}